{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a5d6ad19-adff-423b-8177-30a0d4f6ceed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import accelerate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "02e86302-68c0-455a-9457-6a4618b95cba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: HF_ENDPOINT=https://hf-mirror.com\n"
     ]
    }
   ],
   "source": [
    "%env HF_ENDPOINT=https://hf-mirror.com\n",
    "import os\n",
    "os.environ['HF_HOME'] = '/data1/ckw'\n",
    "os.environ['HF_ENDPOINT']='https://hf-mirror.com'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a45076d6-d882-4450-8596-424d364ba65e",
   "metadata": {},
   "source": [
    "首先，我们构造Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4f4f30f9-d442-4079-9f0b-04874e9de217",
   "metadata": {},
   "outputs": [],
   "source": [
    "import regex as re\n",
    "import base64\n",
    "import os\n",
    "import json\n",
    "import tiktoken\n",
    "from torch import TensorType\n",
    "from typing import List, Optional, Union, Dict, Any\n",
    "from transformers import PreTrainedTokenizer\n",
    "from transformers.utils import logging, PaddingStrategy\n",
    "from transformers.tokenization_utils_base import EncodedInput, BatchEncoding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff7537fb-dcc2-47fc-a4da-a33e8a68f1ce",
   "metadata": {},
   "source": [
    "`ChatGLM4Tokenizer` 类是一个自定义的 `PreTrainedTokenizer` 类，用于处理特定的 GLM-4 模型的分词需求。这个类主要有以下结构和成员函数：\n",
    "\n",
    "### 类结构\n",
    "- **属性**：\n",
    "  - `vocab_files_names`：包含词汇文件名的字典。\n",
    "  - `model_input_names`：包含模型输入名的列表。\n",
    "  - `vocab_file`：词汇文件路径。\n",
    "  - `name`：分词器的名称。\n",
    "  - `pat_str`：正则表达式字符串，用于分词。\n",
    "  - `encode_special_tokens`：是否编码特殊字符。\n",
    "  - `mergeable_ranks`：可合并的词汇排名。\n",
    "  - `tokenizer`：基于 `tiktoken` 库的编码器。\n",
    "  - `decoder`：解码器，映射词汇排名到词汇。\n",
    "  - `n_words`：词汇表大小。\n",
    "\n",
    "### 成员函数\n",
    "\n",
    "- **初始化函数**：\n",
    "  ```python\n",
    "  def __init__(self, vocab_file, padding_side=\"left\", clean_up_tokenization_spaces=False, encode_special_tokens=False, **kwargs)\n",
    "  ```\n",
    "  初始化 `ChatGLM4Tokenizer` 类，加载词汇文件，设置正则表达式和分词器。\n",
    "\n",
    "- **词汇表大小属性**：\n",
    "  ```python\n",
    "  @property\n",
    "  def vocab_size(self)\n",
    "  ```\n",
    "  返回词汇表大小。\n",
    "\n",
    "- **获取词汇表**：\n",
    "  ```python\n",
    "  def get_vocab(self)\n",
    "  ```\n",
    "  返回词汇表字典。\n",
    "\n",
    "- **将 tokens 转换为字符串**：\n",
    "  ```python\n",
    "  def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str\n",
    "  ```\n",
    "  将 tokens 序列转换为字符串。\n",
    "\n",
    "- **分词**：\n",
    "  ```python\n",
    "  def _tokenize(self, text, **kwargs)\n",
    "  ```\n",
    "  将文本分词为 token 列表。\n",
    "\n",
    "- **将 token 转换为 ID**：\n",
    "  ```python\n",
    "  def _convert_token_to_id(self, token)\n",
    "  ```\n",
    "  将 token（字符串）转换为 ID。\n",
    "\n",
    "- **将 ID 转换为 token**：\n",
    "  ```python\n",
    "  def _convert_id_to_token(self, index)\n",
    "  ```\n",
    "  将 ID（整数）转换为 token（字符串）。\n",
    "\n",
    "- **保存词汇表**：\n",
    "  ```python\n",
    "  def save_vocabulary(self, save_directory, filename_prefix=None)\n",
    "  ```\n",
    "  保存词汇表到指定目录。\n",
    "\n",
    "- **获取前缀 tokens**：\n",
    "  ```python\n",
    "  def get_prefix_tokens(self)\n",
    "  ```\n",
    "  返回前缀 tokens 列表。\n",
    "\n",
    "- **构建单条消息**：\n",
    "  ```python\n",
    "  def build_single_message(self, role, metadata, message, tokenize=True)\n",
    "  ```\n",
    "  构建单条消息，包含角色、元数据和消息内容。\n",
    "\n",
    "- **应用聊天模板**：\n",
    "  ```python\n",
    "  def apply_chat_template(self, conversation, add_generation_prompt=False, tokenize=True, padding=False, truncation=False, max_length=None, return_tensors=None, return_dict=False, tokenizer_kwargs=None, add_special_tokens=True, **kwargs)\n",
    "  ```\n",
    "  将会话数据应用到聊天模板中。\n",
    "\n",
    "- **构建带有特殊 token 的输入**：\n",
    "  ```python\n",
    "  def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None)\n",
    "  ```\n",
    "  构建带有特殊 token 的输入。\n",
    "\n",
    "- **填充输入**：\n",
    "  ```python\n",
    "  def _pad(self, encoded_inputs, max_length=None, padding_strategy=PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of=None, return_attention_mask=None)\n",
    "  ```\n",
    "  填充编码后的输入，确保输入长度一致。\n",
    "\n",
    "### 类中的主要函数和方法概述\n",
    "\n",
    "- **初始化**：`__init__` 函数用于初始化分词器，包括加载词汇表、正则表达式的编译和编码器的初始化。\n",
    "- **获取词汇表和大小**：`get_vocab` 和 `vocab_size` 属性用于获取词汇表及其大小。\n",
    "- **分词及转换**：`_tokenize`、`convert_tokens_to_string`、`_convert_token_to_id` 和 `_convert_id_to_token` 函数用于实现文本的分词及 tokens 和 IDs 之间的转换。\n",
    "- **保存和加载**：`save_vocabulary` 函数用于保存词汇表到指定目录。\n",
    "- **会话处理**：`get_prefix_tokens`、`build_single_message` 和 `apply_chat_template` 函数用于处理会话数据和模板应用。\n",
    "- **输入处理**：`build_inputs_with_special_tokens` 和 `_pad` 函数用于处理模型输入，确保输入格式和长度符合要求。\n",
    "\n",
    "通过以上结构和成员函数，`ChatGLM4Tokenizer` 类实现了一个完整的分词器功能，能够处理特定的 GLM-4 模型的分词需求。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "4a119f91-e04b-4f8b-aa2a-02ca20a88a0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChatGLM4Tokenizer(PreTrainedTokenizer):\n",
    "    # 定义词汇文件名和模型输入名\n",
    "    vocab_files_names = {\"vocab_file\": \"tokenizer.model\"}\n",
    "    model_input_names = [\"input_ids\", \"attention_mask\", \"position_ids\"]\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            vocab_file,\n",
    "            padding_side=\"left\",\n",
    "            clean_up_tokenization_spaces=False,\n",
    "            encode_special_tokens=False,\n",
    "            **kwargs\n",
    "    ):\n",
    "        # 初始化一些基础属性\n",
    "        self.name = \"GLM4Tokenizer\"\n",
    "        self.vocab_file = vocab_file\n",
    "        \n",
    "        # 正则表达式模式字符串，用于分词\n",
    "        pat_str = \"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\\\r\\\\n\\\\p{L}\\\\p{N}]?\\\\p{L}+|\\\\p{N}{1,3}| ?[^\\\\s\\\\p{L}\\\\p{N}]+[\\\\r\\\\n]*|\\\\s*[\\\\r\\\\n]+|\\\\s+(?!\\\\S)|\\\\s+\"\n",
    "        # 编译正则表达式\n",
    "        self.pat_str = regex.compile(pat_str)\n",
    "        # self.pat_str = re.compile(pat_str)\n",
    "        # 是否编码特殊字符\n",
    "        self.encode_special_tokens = encode_special_tokens\n",
    "\n",
    "        # 用于存储可合并的词汇排名\n",
    "        mergeable_ranks = {}\n",
    "        # 读取词汇文件\n",
    "        with open(vocab_file) as f:\n",
    "            for line in f:\n",
    "                token, rank = line.strip().split()  # 读取每一行，获取词汇和其对应的排名\n",
    "                rank = int(rank)  # 将排名转换为整数\n",
    "                token = base64.b64decode(token)  # 解码词汇\n",
    "                mergeable_ranks[token] = rank  # 存储到词汇排名字典中\n",
    "\n",
    "        self.mergeable_ranks = mergeable_ranks\n",
    "\n",
    "        # 初始化编码器，使用 tiktoken 库\n",
    "        self.tokenizer = tiktoken.Encoding(\n",
    "            name=\"my_tokenizer\",\n",
    "            pat_str=pat_str,\n",
    "            mergeable_ranks=mergeable_ranks,\n",
    "            special_tokens={}\n",
    "        )\n",
    "        # 解码器，映射排名到词汇\n",
    "        self.decoder = {rank: token for token, rank in mergeable_ranks.items()}\n",
    "        # 词汇表大小\n",
    "        self.n_words = len(self.decoder)\n",
    "\n",
    "        # 调用父类的初始化方法\n",
    "        super().__init__(\n",
    "            padding_side=padding_side,\n",
    "            clean_up_tokenization_spaces=clean_up_tokenization_spaces,\n",
    "            **kwargs\n",
    "        )\n",
    "\n",
    "    @property\n",
    "    def vocab_size(self):\n",
    "        # 返回词汇表大小\n",
    "        return self.n_words\n",
    "\n",
    "    def get_vocab(self):\n",
    "        \"\"\" 返回词汇表字典 \"\"\"\n",
    "        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}\n",
    "        vocab.update(self.added_tokens_encoder)\n",
    "        return vocab\n",
    "\n",
    "    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:\n",
    "        \"\"\"\n",
    "        将 tokens 序列转换为字符串。\n",
    "        \"\"\"\n",
    "        text = \"\"\n",
    "        temp = b\"\"\n",
    "        for t in tokens:\n",
    "            if isinstance(t, str):\n",
    "                if temp:\n",
    "                    text += temp.decode(\"utf-8\", errors=\"replace\")\n",
    "                    temp = b\"\"\n",
    "                text += t\n",
    "            elif isinstance(t, bytes):\n",
    "                temp += t\n",
    "            else:\n",
    "                raise TypeError(\"token should only be of type bytes or str\")\n",
    "        if temp:\n",
    "            text += temp.decode(\"utf-8\", errors=\"replace\")\n",
    "        return text\n",
    "\n",
    "    def _tokenize(self, text, **kwargs):\n",
    "        # 使用正则表达式和编码器对文本进行分词\n",
    "        tokens = []\n",
    "        ids = self.tokenizer.encode(text)\n",
    "        for t in ids:\n",
    "            tokens.append(self.decoder[t])\n",
    "        return tokens\n",
    "\n",
    "    def _convert_token_to_id(self, token):\n",
    "        \"\"\" 将 token (字符串) 转换为 id (整数) \"\"\"\n",
    "        return self.mergeable_ranks[token]\n",
    "\n",
    "    def _convert_id_to_token(self, index):\n",
    "        \"\"\" 将 id (整数) 转换为 token (字符串) \"\"\"\n",
    "        return self.decoder.get(index, \"\")\n",
    "\n",
    "    def save_vocabulary(self, save_directory, filename_prefix=None):\n",
    "        \"\"\"\n",
    "        将词汇表和特殊字符文件保存到指定目录。\n",
    "\n",
    "        Args:\n",
    "            save_directory (`str`): 保存词汇表的目录。\n",
    "            filename_prefix (`str`, *optional*): 保存文件名的前缀。\n",
    "\n",
    "        Returns:\n",
    "            `Tuple(str)`: 保存的文件路径。\n",
    "        \"\"\"\n",
    "        if os.path.isdir(save_directory):\n",
    "            vocab_file = os.path.join(\n",
    "                save_directory, self.vocab_files_names[\"vocab_file\"]\n",
    "            )\n",
    "        else:\n",
    "            vocab_file = save_directory\n",
    "\n",
    "        with open(self.vocab_file, 'rb') as fin:\n",
    "            proto_str = fin.read()\n",
    "\n",
    "        with open(vocab_file, \"wb\") as writer:\n",
    "            writer.write(proto_str)\n",
    "\n",
    "        return (vocab_file,)\n",
    "\n",
    "    def get_prefix_tokens(self):\n",
    "        # 返回前缀 tokens 列表\n",
    "        prefix_tokens = [self.convert_tokens_to_ids(\"[gMASK]\"), self.convert_tokens_to_ids(\"<sop>\")]\n",
    "        return prefix_tokens\n",
    "\n",
    "    def build_single_message(self, role, metadata, message, tokenize=True):\n",
    "        # 构建单条消息，包含角色、元数据和消息内容\n",
    "        assert role in [\"system\", \"user\", \"assistant\", \"observation\"], role\n",
    "        if tokenize:\n",
    "            role_tokens = [self.convert_tokens_to_ids(f\"<|{role}|>\")] + self.tokenizer.encode(f\"{metadata}\\n\",\n",
    "                                                                                              disallowed_special=())\n",
    "            message_tokens = self.tokenizer.encode(message, disallowed_special=())\n",
    "            tokens = role_tokens + message_tokens\n",
    "            return tokens\n",
    "        else:\n",
    "            return str(f\"<|{role}|>{metadata}\\n{message}\")\n",
    "\n",
    "    def apply_chat_template(\n",
    "            self,\n",
    "            conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], \"Conversation\"],\n",
    "            add_generation_prompt: bool = False,\n",
    "            tokenize: bool = True,\n",
    "            padding: bool = False,\n",
    "            truncation: bool = False,\n",
    "            max_length: Optional[int] = None,\n",
    "            return_tensors: Optional[Union[str, TensorType]] = None,\n",
    "            return_dict: bool = False,\n",
    "            tokenizer_kwargs: Optional[Dict[str, Any]] = None,\n",
    "            add_special_tokens: bool = True,\n",
    "            **kwargs,\n",
    "    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:\n",
    "    \n",
    "        if return_dict and not tokenize:\n",
    "            raise ValueError(\n",
    "                \"`return_dict=True` is incompatible with `tokenize=False`, because there is no dict \"\n",
    "                \"of tokenizer outputs to return.\"\n",
    "            )\n",
    "    \n",
    "        def handle_single_conversation(conversation):\n",
    "            input_ids = self.get_prefix_tokens() if add_special_tokens else []\n",
    "            input_message = \"[gMASK]<sop>\" if add_special_tokens else \"\"\n",
    "            for item in conversation:\n",
    "                if item.get(\"tools\"):\n",
    "                    tools = item[\"tools\"]\n",
    "                    content = \"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。\"\n",
    "                    for tool in tools:\n",
    "                        if tool[\"type\"] == \"function\":\n",
    "                            function = tool[\"function\"]\n",
    "                            content += f\"\\n\\n## {function['name']}\\n\\n{json.dumps(function, ensure_ascii=False, indent=4)}\"\n",
    "                            content += \"\\n在调用上述函数时，请使用 Json 格式表示调用的参数。\"\n",
    "                        elif tool[\"type\"] == \"python\":\n",
    "                            content += \"\\n\\n## python\\n\\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。\"\n",
    "                        elif tool[\"type\"] == \"simple_browser\":\n",
    "                            content += \"\\n\\n## simple_browser\\n\\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\\n`open_url(url: str)`：打开指定的 URL。\\n\\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\\n\\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。\"\n",
    "                        elif tool[\"type\"] == \"cogview\":\n",
    "                            content += \"\\n\\n## cogview\\n\\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。\"\n",
    "                        else:\n",
    "                            raise NotImplementedError(f\"Unknown tool type {tool['type']}\")\n",
    "                    input = self.build_single_message(\"system\", \"\", content, tokenize=tokenize)\n",
    "                    if tokenize:\n",
    "                        input_ids.extend(input)\n",
    "                    else:\n",
    "                        input_message += input\n",
    "                if item[\"content\"]:\n",
    "                    input = self.build_single_message(\n",
    "                        item[\"role\"],\n",
    "                        item.get(\"metadata\", \"\"),\n",
    "                        item[\"content\"],\n",
    "                        tokenize=tokenize\n",
    "                    )\n",
    "                    if tokenize:\n",
    "                        input_ids.extend(input)\n",
    "                    else:\n",
    "                        input_message += input\n",
    "            if add_generation_prompt:\n",
    "                if tokenize:\n",
    "                    input_ids.extend([self.convert_tokens_to_ids(\"<|assistant|>\")])\n",
    "                else:\n",
    "                    input_message += \"<|assistant|>\"\n",
    "                # if tokenize:\n",
    "                #     input_ids.extend([self.convert_tokens_to_ids(\"[gMASK]\")])  # 使用特殊标记代替空字符串\n",
    "                # else:\n",
    "                #     input_message += \"[gMASK]\"\n",
    "    \n",
    "            return input_ids if tokenize else input_message\n",
    "    \n",
    "        # 处理不同会话格式的主逻辑\n",
    "        if isinstance(conversation, list) and all(isinstance(i, dict) for i in conversation):\n",
    "            result = handle_single_conversation(conversation)\n",
    "        elif isinstance(conversation, list) and all(isinstance(i, list) for i in conversation):\n",
    "            result = [handle_single_conversation(c) for c in conversation]\n",
    "        elif hasattr(conversation, \"messages\"):\n",
    "            result = handle_single_conversation(conversation.messages)\n",
    "        else:\n",
    "            raise ValueError(\"Invalid conversation format\")\n",
    "    \n",
    "        if tokenize:\n",
    "            output = self.batch_encode_plus(\n",
    "                [result] if isinstance(result[0], int) else result,\n",
    "                padding=padding,\n",
    "                truncation=truncation,\n",
    "                max_length=max_length,\n",
    "                return_tensors=return_tensors,\n",
    "                is_split_into_words=True,\n",
    "                add_special_tokens=False\n",
    "            )\n",
    "            if return_dict:\n",
    "                return output\n",
    "            else:\n",
    "                return output[\"input_ids\"]\n",
    "        else:\n",
    "            return result\n",
    "\n",
    "\n",
    "\n",
    "    def build_inputs_with_special_tokens(\n",
    "            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n",
    "    ) -> List[int]:\n",
    "        \"\"\"\n",
    "        构建模型输入，适用于序列分类任务，通过连接和添加特殊 tokens。\n",
    "        BERT 序列格式：\n",
    "        - 单序列: `[CLS] X [SEP]`\n",
    "        - 序列对: `[CLS] A [SEP] B [SEP]`\n",
    "\n",
    "        Args:\n",
    "            token_ids_0 (`List[int]`): 要添加特殊 tokens 的 ID 列表。\n",
    "            token_ids_1 (`List[int]`, *optional*): 可选的第二个 ID 列表，表示序列对。\n",
    "\n",
    "        Returns:\n",
    "            `List[int]`: 添加了适当特殊 tokens 的输入 ID 列表。\n",
    "        \"\"\"\n",
    "        prefix_tokens = self.get_prefix_tokens()\n",
    "        token_ids_0 = prefix_tokens + token_ids_0\n",
    "        if token_ids_1 is not None:\n",
    "            token_ids_0 = token_ids_0 + token_ids_1 + [self.convert_tokens_to_ids(\"<eos>\")]\n",
    "        return token_ids_0\n",
    "\n",
    "    def _pad(\n",
    "            self,\n",
    "            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],\n",
    "            max_length: Optional[int] = None,\n",
    "            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,\n",
    "            pad_to_multiple_of: Optional[int] = None,\n",
    "            return_attention_mask: Optional[bool] = None,\n",
    "    ) -> dict:\n",
    "        \"\"\"\n",
    "        填充编码后的输入（左右填充，并预定义长度或批处理中的最大长度）\n",
    "\n",
    "        Args:\n",
    "            encoded_inputs: 包含编码后的输入 (`List[int]`) 或批处理的输入 (`List[List[int]]`) 的字典。\n",
    "            max_length: 返回列表的最大长度，并可选的填充长度。\n",
    "            padding_strategy: 用于填充的策略。\n",
    "                - PaddingStrategy.LONGEST: 填充到批处理中的最长序列\n",
    "                - PaddingStrategy.MAX_LENGTH: 填充到最大长度（默认）\n",
    "                - PaddingStrategy.DO_NOT_PAD: 不填充\n",
    "                填充策略由 self.padding_side 定义：\n",
    "                    - 'left': 在序列左侧填充\n",
    "                    - 'right': 在序列右侧填充\n",
    "            pad_to_multiple_of: (可选) 整数，如果设置将填充序列为该值的倍数。\n",
    "            return_attention_mask: (可选) 设置为 False 以避免返回注意力掩码（默认：根据模型具体情况设置）\n",
    "        \"\"\"\n",
    "        # 确保填充侧为 'left'\n",
    "        assert self.padding_side == \"left\"\n",
    "\n",
    "        required_input = encoded_inputs[self.model_input_names[0]]\n",
    "        seq_length = len(required_input)\n",
    "\n",
    "        if padding_strategy == PaddingStrategy.LONGEST:\n",
    "            max_length = len(required_input)\n",
    "\n",
    "        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):\n",
    "            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of\n",
    "\n",
    "        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length\n",
    "\n",
    "        # 如果没有注意力掩码，初始化\n",
    "        if \"attention_mask\" not in encoded_inputs:\n",
    "            encoded_inputs[\"attention_mask\"] = [1] * seq_length\n",
    "\n",
    "        if \"position_ids\" not in encoded_inputs:\n",
    "            encoded_inputs[\"position_ids\"] = list(range(seq_length))\n",
    "\n",
    "        if needs_to_be_padded:\n",
    "            difference = max_length - len(required_input)\n",
    "\n",
    "            if \"attention_mask\" in encoded_inputs:\n",
    "                encoded_inputs[\"attention_mask\"] = [0] * difference + encoded_inputs[\"attention_mask\"]\n",
    "            if \"position_ids\" in encoded_inputs:\n",
    "                encoded_inputs[\"position_ids\"] = [0] * difference + encoded_inputs[\"position_ids\"]\n",
    "            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input\n",
    "\n",
    "        return encoded_inputs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bd7c0c2-83eb-4f01-a9c6-0b00ef9d05b2",
   "metadata": {},
   "source": [
    "### 关键点解释\n",
    "\n",
    "1. **`mergeable_ranks`**:\n",
    "   - 存储可合并词汇的排名信息。读取词汇文件时，每个词汇都有一个对应的排名（整数），这些词汇被解码为原始字符串，然后存储在 `mergeable_ranks` 字典中。\n",
    "\n",
    "2. **`pat_str`**:\n",
    "   - 正则表达式字符串，用于定义分词模式。这段正则表达式主要用于匹配常见的英语缩写、单词、数字和其他非字母数字字符。编译后的正则表达式存储在 `self.pat_str` 中，用于分词器。\n",
    "\n",
    "3. **`self.tokenizer`**:\n",
    "   - 使用 `tiktoken` 库创建的编码器。通过提供的正则表达式模式和可合并的词汇排名，初始化一个自定义的分词器。\n",
    "\n",
    "4. **`self.decoder`**:\n",
    "   - 解码器，映射词汇排名到词汇。使用 `mergeable_ranks` 字典创建的逆向字典，用于从 ID 转换回原始词汇。\n",
    "\n",
    "5. **`self.n_words`**:\n",
    "   - 词汇表大小，即词汇数量。通过计算 `self.decoder` 的长度获得。\n",
    "\n",
    "6. **初始化过程**：\n",
    "   - 读取词汇文件，解析每一行以获取词汇及其排名，然后解码词汇并存储到 `mergeable_ranks` 字典中。接着，使用 `tiktoken` 库初始化编码器，并创建解码器和词汇表大小。\n",
    "\n",
    "其中有一点值得关注：\n",
    "```python\n",
    "    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:\n",
    "        \"\"\"\n",
    "        将 tokens 序列转换为字符串。\n",
    "        \"\"\"\n",
    "        text = \"\"\n",
    "        temp = b\"\"\n",
    "        for t in tokens:\n",
    "            if isinstance(t, str):\n",
    "                if temp:\n",
    "                    text += temp.decode(\"utf-8\", errors=\"replace\")\n",
    "                    temp = b\"\"\n",
    "                text += t\n",
    "            elif isinstance(t, bytes):\n",
    "                temp += t\n",
    "            else:\n",
    "                raise TypeError(\"token should only be of type bytes or str\")\n",
    "        if temp:\n",
    "            text += temp.decode(\"utf-8\", errors=\"replace\")\n",
    "        return text\n",
    "```\n",
    "\n",
    "这个函数之所以要将 `token` 解码为 `bytes` 是因为在实际应用中，token 可能是字符串或者字节序列的一部分。在某些情况下，词汇表中的 token 可能以 `bytes` 的形式存储，而不是直接存储为字符串。以下是一些详细原因和场景：\n",
    "\n",
    "1. **支持多种编码格式**：\n",
    "   - 分词器可能处理多种数据源，其中一些数据源可能以 `bytes` 格式存储 token。例如，当数据被压缩、加密或使用特定编码时，token 可能会以字节形式表示。\n",
    "\n",
    "2. **数据处理的灵活性**：\n",
    "   - 通过支持 `bytes` 和 `str` 两种类型，分词器可以更灵活地处理不同来源和格式的输入数据。这在处理包含二进制数据的文本时特别有用，例如一些特殊的标记或字符。\n",
    "\n",
    "3. **确保数据一致性**：\n",
    "   - 在分词和解码过程中，可能会遇到混合类型的 token（即既有字符串也有字节序列）。为了确保数据一致性并正确拼接字符串，函数需要处理 `bytes` 和 `str` 两种类型。\n",
    "\n",
    "4. **兼容性考虑**：\n",
    "   - 某些 NLP 工具和库在处理文本时可能会返回字节形式的 token。为了与这些工具和库兼容，分词器需要能够处理和转换这些字节 token。\n",
    "\n",
    "具体来看这个函数的工作流程：\n",
    "\n",
    "1. **初始化空字符串和字节序列**：\n",
    "   - `text = \"\"` 初始化一个空字符串，用于存储最终的结果。\n",
    "   - `temp = b\"\"` 初始化一个空字节序列，用于临时存储字节 token。\n",
    "\n",
    "2. **遍历 token 列表**：\n",
    "   - 对于每个 token，检查其类型。\n",
    "     - 如果 token 是字符串类型（`str`），检查 `temp` 是否为空。如果 `temp` 非空，将 `temp` 解码为字符串并添加到 `text`，然后清空 `temp`。之后，将当前字符串 token 添加到 `text`。\n",
    "     - 如果 token 是字节类型（`bytes`），将其添加到 `temp`，以便后续解码。\n",
    "     - 如果 token 既不是字符串也不是字节类型，抛出类型错误。\n",
    "\n",
    "3. **处理剩余的字节序列**：\n",
    "   - 在循环结束后，如果 `temp` 非空，将其解码为字符串并添加到 `text`。\n",
    "\n",
    "通过上述步骤，函数能够正确处理混合类型的 token 列表，并将其转换为一个完整的字符串。以下是代码中的注释，以更好地解释这些步骤：\n",
    "\n",
    "```python\n",
    "def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:\n",
    "    \"\"\"\n",
    "    将 tokens 序列转换为字符串。\n",
    "    \"\"\"\n",
    "    text = \"\"  # 初始化一个空字符串用于存储结果\n",
    "    temp = b\"\"  # 初始化一个空字节序列用于临时存储字节 token\n",
    "    for t in tokens:\n",
    "        if isinstance(t, str):\n",
    "            if temp:\n",
    "                # 如果 temp 非空，将其解码为字符串并添加到 text 中\n",
    "                text += temp.decode(\"utf-8\", errors=\"replace\")\n",
    "                temp = b\"\"  # 清空 temp\n",
    "            text += t  # 将字符串 token 添加到 text 中\n",
    "        elif isinstance(t, bytes):\n",
    "            temp += t  # 将字节 token 添加到 temp 中\n",
    "        else:\n",
    "            raise TypeError(\"token should only be of type bytes or str\")  # 抛出类型错误\n",
    "    if temp:\n",
    "        # 处理剩余的字节序列，将其解码为字符串并添加到 text 中\n",
    "        text += temp.decode(\"utf-8\", errors=\"replace\")\n",
    "    return text  # 返回最终的字符串\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c439bb2-5250-4891-973d-f973bd14aae3",
   "metadata": {},
   "source": [
    "接下来设置config，简单设置相关属性即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "ac8bdb0a-b279-4365-b6bb-224b0fe43d99",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import PretrainedConfig\n",
    "\n",
    "\n",
    "class ChatGLMConfig(PretrainedConfig):\n",
    "    model_type = \"chatglm\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            num_layers=28,\n",
    "            padded_vocab_size=65024,\n",
    "            hidden_size=4096,\n",
    "            ffn_hidden_size=13696,\n",
    "            kv_channels=128,\n",
    "            num_attention_heads=32,\n",
    "            seq_length=2048,\n",
    "            hidden_dropout=0.0,\n",
    "            classifier_dropout=None,\n",
    "            attention_dropout=0.0,\n",
    "            layernorm_epsilon=1e-5,\n",
    "            rmsnorm=True,\n",
    "            apply_residual_connection_post_layernorm=False,\n",
    "            post_layer_norm=True,\n",
    "            add_bias_linear=False,\n",
    "            add_qkv_bias=False,\n",
    "            bias_dropout_fusion=True,\n",
    "            multi_query_attention=False,\n",
    "            multi_query_group_num=1,\n",
    "            rope_ratio=1,\n",
    "            apply_query_key_layer_scaling=True,\n",
    "            attention_softmax_in_fp32=True,\n",
    "            fp32_residual_connection=False,\n",
    "            **kwargs\n",
    "    ):\n",
    "        self.num_layers = num_layers\n",
    "        self.vocab_size = padded_vocab_size\n",
    "        self.padded_vocab_size = padded_vocab_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.ffn_hidden_size = ffn_hidden_size\n",
    "        self.kv_channels = kv_channels\n",
    "        self.num_attention_heads = num_attention_heads\n",
    "        self.seq_length = seq_length\n",
    "        self.hidden_dropout = hidden_dropout\n",
    "        self.classifier_dropout = classifier_dropout\n",
    "        self.attention_dropout = attention_dropout\n",
    "        self.layernorm_epsilon = layernorm_epsilon\n",
    "        self.rmsnorm = rmsnorm\n",
    "        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm\n",
    "        self.post_layer_norm = post_layer_norm\n",
    "        self.add_bias_linear = add_bias_linear\n",
    "        self.add_qkv_bias = add_qkv_bias\n",
    "        self.bias_dropout_fusion = bias_dropout_fusion\n",
    "        self.multi_query_attention = multi_query_attention\n",
    "        self.multi_query_group_num = multi_query_group_num\n",
    "        self.rope_ratio = rope_ratio\n",
    "        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling\n",
    "        self.attention_softmax_in_fp32 = attention_softmax_in_fp32\n",
    "        self.fp32_residual_connection = fp32_residual_connection\n",
    "        super().__init__(**kwargs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7936d603-79a0-41ae-8777-185c5562e0bb",
   "metadata": {},
   "source": [
    "然后就可以开始构建我们的模型了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "851a39f0-1cad-48c2-a0ce-b781ed98c07f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import math\n",
    "import copy\n",
    "import warnings\n",
    "import re\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "import torch.utils.checkpoint\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss\n",
    "from torch.nn.utils import skip_init\n",
    "from typing import Optional, Tuple, Union, List, Callable, Dict, Any\n",
    "from copy import deepcopy\n",
    "\n",
    "from transformers.modeling_outputs import (\n",
    "    BaseModelOutputWithPast,\n",
    "    CausalLMOutputWithPast,\n",
    "    SequenceClassifierOutputWithPast,\n",
    ")\n",
    "from transformers.modeling_utils import PreTrainedModel\n",
    "from transformers.utils import logging\n",
    "from transformers.generation.logits_process import LogitsProcessor\n",
    "from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "8ff61078-3447-4b2d-a77a-034000ce38d2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'linux'"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sys.platform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "eaa943a3-5a6a-4f5a-9d1c-a1d2619e090e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 如果系统平台不是 Darwin (即 macOS或 linux)，则进行以下设置以提高性能\n",
    "if sys.platform != 'darwin':\n",
    "    # 关闭 JIT 的 profiling 模式\n",
    "    torch._C._jit_set_profiling_mode(False)\n",
    "    # 关闭 JIT 的 profiling 执行器\n",
    "    torch._C._jit_set_profiling_executor(False)\n",
    "    # 允许在 CPU 上进行张量融合\n",
    "    torch._C._jit_override_can_fuse_on_cpu(True)\n",
    "    # 允许在 GPU 上进行张量融合\n",
    "    torch._C._jit_override_can_fuse_on_gpu(True)\n",
    "\n",
    "# 获取日志记录器\n",
    "logger = logging.get_logger(__name__)\n",
    "\n",
    "# 用于文档的检查点\n",
    "_CHECKPOINT_FOR_DOC = \"THUDM/ChatGLM\"\n",
    "# 用于文档的配置\n",
    "_CONFIG_FOR_DOC = \"ChatGLMConfig\"\n",
    "\n",
    "# 默认初始化函数\n",
    "def default_init(cls, *args, **kwargs):\n",
    "    return cls(*args, **kwargs)\n",
    "\n",
    "# 无效分数 Logits 处理器类\n",
    "class InvalidScoreLogitsProcessor(LogitsProcessor):\n",
    "    # 处理输入的 logits\n",
    "    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:\n",
    "        # 如果 logits 中存在 NaN 或无穷大，则将其重置\n",
    "        if torch.isnan(scores).any() or torch.isinf(scores).any():\n",
    "            scores.zero_()  # 将所有分数重置为 0\n",
    "            scores[..., 198] = 5e4  # 将特定位置的分数设置为一个大值\n",
    "        return scores\n",
    "\n",
    "# 沿着最后一个维度分割张量的函数\n",
    "def split_tensor_along_last_dim(\n",
    "        tensor: torch.Tensor,\n",
    "        num_partitions: int,\n",
    "        contiguous_split_chunks: bool = False,\n",
    ") -> List[torch.Tensor]:\n",
    "    \"\"\"\n",
    "    沿着最后一个维度分割张量。\n",
    "\n",
    "    参数:\n",
    "        tensor: 输入张量。\n",
    "        num_partitions: 要分割的部分数量。\n",
    "        contiguous_split_chunks: 如果为 True，则使每个分块在内存中是连续的。\n",
    "\n",
    "    返回:\n",
    "        张量的列表\n",
    "    \"\"\"\n",
    "    # 获取最后一个维度的大小\n",
    "    last_dim = tensor.dim() - 1\n",
    "    last_dim_size = tensor.size()[last_dim] // num_partitions\n",
    "    # 分割张量\n",
    "    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n",
    "    # 注意: torch.split 默认不会创建连续的张量\n",
    "    if contiguous_split_chunks:\n",
    "        return tuple(chunk.contiguous() for chunk in tensor_list)\n",
    "    return tensor_list\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52adb283-3a15-4a82-a3e5-b80f2066c891",
   "metadata": {},
   "source": [
    "### 关键点解释\n",
    "\n",
    "1. **系统平台检查**：\n",
    "   - 如果系统平台不是 `Darwin`（macOS），则关闭 JIT 的 profiling 模式和执行器，并启用 CPU 和 GPU 上的张量融合。这些设置可以提高模型的性能。\n",
    "\n",
    "2. **日志记录器**：\n",
    "   - 获取一个日志记录器，用于记录模型构建过程中的信息和调试。\n",
    "\n",
    "3. **文档字符串**：\n",
    "   - `_CHECKPOINT_FOR_DOC` 和 `_CONFIG_FOR_DOC` 是用于文档生成的字符串常量，指示模型和配置的检查点和配置文件。\n",
    "\n",
    "4. **默认初始化函数**：\n",
    "   - `default_init` 函数是一个通用的初始化函数，用于实例化类。\n",
    "\n",
    "5. **`InvalidScoreLogitsProcessor` 类**：\n",
    "   - 这是一个 logits 处理器类，用于处理无效的 logits 分数。如果分数中存在 `NaN` 或无穷大值，将这些值重置为 0，并将特定位置（例如 198）的分数设置为一个很大的值（`5e4`），以确保模型输出有效的结果。\n",
    "\n",
    "6. **`split_tensor_along_last_dim` 函数**：\n",
    "   - 这个函数用于沿着张量的最后一个维度分割张量。参数包括输入张量、要分割的部分数量和一个布尔值（是否使每个分块在内存中是连续的）。如果 `contiguous_split_chunks` 为真，则每个分块在内存中是连续的；否则，直接返回分割后的张量列表。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "7f8005f7-1339-40b4-b103-6bf0396b371a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 旋转位置嵌入类\n",
    "class RotaryEmbedding(nn.Module):\n",
    "    def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):\n",
    "        super().__init__()\n",
    "        # 计算倒数频率\n",
    "        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))\n",
    "        # 注册倒数频率为 buffer\n",
    "        self.register_buffer(\"inv_freq\", inv_freq)\n",
    "        self.dim = dim\n",
    "        self.original_impl = original_impl\n",
    "        self.rope_ratio = rope_ratio\n",
    "\n",
    "    # 实现前向传播的具体方法\n",
    "    def forward_impl(\n",
    "            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000\n",
    "    ):\n",
    "        \"\"\"\n",
    "        增强的 Transformer 使用旋转位置嵌入。\n",
    "\n",
    "        参考自:\n",
    "        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/\n",
    "        transformers/rope/__init__.py. MIT 许可证:\n",
    "        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.\n",
    "        \"\"\"\n",
    "        # 计算 $\\Theta = {\\theta_i = 10000^{\\frac{2(i-1)}{d}}, i \\in [1, 2, ..., \\frac{d}{2}]}$\n",
    "        base = base * self.rope_ratio\n",
    "        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))\n",
    "\n",
    "        # 创建位置索引 `[0, 1, ..., seq_len - 1]`\n",
    "        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)\n",
    "\n",
    "        # 计算位置索引和 $\\theta_i$ 的乘积\n",
    "        idx_theta = torch.outer(seq_idx, theta).float()\n",
    "\n",
    "        # 缓存计算的余弦和正弦值\n",
    "        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)\n",
    "\n",
    "        # 模拟 complex32 的行为，否则会得到不同的结果\n",
    "        if dtype in (torch.float16, torch.bfloat16, torch.int8):\n",
    "            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()\n",
    "        return cache\n",
    "\n",
    "    # 前向传播方法\n",
    "    def forward(self, max_seq_len, offset=0):\n",
    "        return self.forward_impl(\n",
    "            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device\n",
    "        )\n",
    "\n",
    "# 使用 TorchScript JIT 编译器优化的旋转位置嵌入应用函数\n",
    "@torch.jit.script\n",
    "def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:\n",
    "    # x: [b, np, sq, hn]\n",
    "    b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)\n",
    "    rot_dim = rope_cache.shape[-2] * 2\n",
    "    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]\n",
    "    # 截断以支持可变大小\n",
    "    rope_cache = rope_cache[:, :sq]\n",
    "    xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)\n",
    "    rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)\n",
    "    x_out2 = torch.stack(\n",
    "        [\n",
    "            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],\n",
    "            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],\n",
    "        ],\n",
    "        -1,\n",
    "    )\n",
    "    x_out2 = x_out2.flatten(3)\n",
    "    return torch.cat((x_out2, x_pass), dim=-1)\n",
    "\n",
    "# RMS 归一化层类\n",
    "class RMSNorm(torch.nn.Module):\n",
    "    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):\n",
    "        super().__init__()\n",
    "        # 初始化权重参数\n",
    "        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))\n",
    "        self.eps = eps\n",
    "\n",
    "    def forward(self, hidden_states: torch.Tensor):\n",
    "        # 获取输入张量的数据类型\n",
    "        input_dtype = hidden_states.dtype\n",
    "        # 计算方差\n",
    "        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n",
    "        # 进行 RMS 归一化\n",
    "        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)\n",
    "        # 应用权重并返回与输入相同的数据类型\n",
    "        return (self.weight * hidden_states).to(input_dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1f66720-11b0-4bc5-9468-083e952ce6de",
   "metadata": {},
   "source": [
    "### 关键点解释\n",
    "\n",
    "1. **`RotaryEmbedding` 类**：\n",
    "   - 用于实现旋转位置嵌入的模块。\n",
    "   - `inv_freq` 是一个倒数频率的张量，用于位置嵌入的计算。\n",
    "   - `forward_impl` 方法根据序列长度和嵌入维度计算旋转位置嵌入。\n",
    "   - `forward` 方法调用 `forward_impl` 进行前向传播。\n",
    "\n",
    "2. **`apply_rotary_pos_emb` 函数**：\n",
    "   - 使用旋转位置嵌入更新输入张量。\n",
    "   - `x` 是输入张量，`rope_cache` 是预计算的旋转位置嵌入。\n",
    "   - 该函数首先对输入张量进行分割和重塑，然后将旋转位置嵌入应用于每个分块，最后将结果拼接回原始张量。\n",
    "\n",
    "3. **`RMSNorm` 类**：\n",
    "   - 实现 RMS 归一化的模块。\n",
    "   - `weight` 是归一化的权重参数。\n",
    "   - `forward` 方法计算输入张量的方差，并进行归一化处理，然后应用权重。\n",
    "\n",
    "这些注释和解释可以帮助理解每个部分的功能和实现细节，对于模型构建和调试非常有用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "b76dc446-e37e-4c1f-94f1-96f209bd674e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CoreAttention(torch.nn.Module):\n",
    "    def __init__(self, config: ChatGLMConfig, layer_number):\n",
    "        super(CoreAttention, self).__init__()\n",
    "\n",
    "        # 配置参数\n",
    "        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling\n",
    "        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32\n",
    "        if self.apply_query_key_layer_scaling:\n",
    "            self.attention_softmax_in_fp32 = True\n",
    "        self.layer_number = max(1, layer_number)\n",
    "\n",
    "        # 计算投影大小\n",
    "        projection_size = config.kv_channels * config.num_attention_heads\n",
    "\n",
    "        # 每个注意力头和每个分区的值\n",
    "        self.hidden_size_per_partition = projection_size\n",
    "        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads\n",
    "        self.num_attention_heads_per_partition = config.num_attention_heads\n",
    "\n",
    "        coeff = None\n",
    "        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)\n",
    "        if self.apply_query_key_layer_scaling:\n",
    "            coeff = self.layer_number\n",
    "            self.norm_factor *= coeff\n",
    "        self.coeff = coeff\n",
    "\n",
    "        # 注意力 dropout\n",
    "        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)\n",
    "\n",
    "    def forward(self, query_layer, key_layer, value_layer, attention_mask):\n",
    "        pytorch_major_version = int(torch.__version__.split('.')[0])\n",
    "        if pytorch_major_version >= 2:\n",
    "            # PyTorch 2.0 及以上版本\n",
    "            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:\n",
    "                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,\n",
    "                                                                                 is_causal=True)\n",
    "            else:\n",
    "                if attention_mask is not None:\n",
    "                    attention_mask = ~attention_mask\n",
    "                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,\n",
    "                                                                                 attention_mask)\n",
    "            context_layer = context_layer.transpose(1, 2).contiguous()\n",
    "            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)\n",
    "            context_layer = context_layer.reshape(*new_context_layer_shape)\n",
    "        else:\n",
    "            # 处理 PyTorch 2.0 以下版本\n",
    "\n",
    "            # 原始注意力得分\n",
    "            # [b, np, sq, sk]\n",
    "            output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))\n",
    "\n",
    "            # 重新调整视图 [b, np, sq, hn] -> [b * np, sq, hn]\n",
    "            query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)\n",
    "            # 重新调整视图 [b, np, sk, hn] -> [b * np, sk, hn]\n",
    "            key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)\n",
    "\n",
    "            # 预分配输入张量: [b * np, sq, sk]\n",
    "            matmul_input_buffer = torch.empty(\n",
    "                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,\n",
    "                device=query_layer.device\n",
    "            )\n",
    "\n",
    "            # 计算原始注意力得分. [b * np, sq, sk]\n",
    "            matmul_result = torch.baddbmm(\n",
    "                matmul_input_buffer,\n",
    "                query_layer,  # [b * np, sq, hn]\n",
    "                key_layer.transpose(1, 2),  # [b * np, hn, sk]\n",
    "                beta=0.0,\n",
    "                alpha=(1.0 / self.norm_factor),\n",
    "            )\n",
    "\n",
    "            # 改变视图到 [b, np, sq, sk]\n",
    "            attention_scores = matmul_result.view(*output_size)\n",
    "\n",
    "            # 处理注意力得分和 dropout\n",
    "            if self.attention_softmax_in_fp32:\n",
    "                attention_scores = attention_scores.float()\n",
    "            if self.coeff is not None:\n",
    "                attention_scores = attention_scores * self.coeff\n",
    "            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:\n",
    "                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],\n",
    "                                            device=attention_scores.device, dtype=torch.bool)\n",
    "                attention_mask.tril_()\n",
    "                attention_mask = ~attention_mask\n",
    "            if attention_mask is not None:\n",
    "                attention_scores = attention_scores.masked_fill(attention_mask, float(\"-inf\"))\n",
    "            attention_probs = F.softmax(attention_scores, dim=-1)\n",
    "            attention_probs = attention_probs.type_as(value_layer)\n",
    "\n",
    "            # 丢弃整个 token 的注意力，这源自原始的 Transformer 论文\n",
    "            attention_probs = self.attention_dropout(attention_probs)\n",
    "\n",
    "            # 重新调整视图 [b * np, sq, hn]\n",
    "            value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)\n",
    "            # 重新调整视图 [b * np, sq, sk]\n",
    "            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)\n",
    "            # 计算上下文层 [b * np, sq, hn]\n",
    "            context_layer = torch.bmm(attention_probs, value_layer)\n",
    "            # 重新调整视图 [b, np, sq, hn]\n",
    "            context_layer = context_layer.view(*output_size)\n",
    "            # [b, np, sq, hn] --> [b, sq, np, hn]\n",
    "            context_layer = context_layer.transpose(1, 2).contiguous()\n",
    "            # [b, sq, np, hn] --> [b, sq, hp]\n",
    "            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)\n",
    "            context_layer = context_layer.reshape(*new_context_layer_shape)\n",
    "\n",
    "        return context_layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a6f04a8-0373-4a78-baa9-52e665fe604d",
   "metadata": {},
   "source": [
    "### CoreAttention 详解及公式\n",
    "\n",
    "在神经网络模型中，特别是 Transformer 架构中，注意力机制起着至关重要的作用。CoreAttention 实现了自注意力机制中的核心部分，具体来说包括以下步骤：计算注意力得分、应用注意力掩码、计算注意力权重以及计算上下文向量。\n",
    "\n",
    "#### 1. 计算注意力得分 (Attention Scores)\n",
    "\n",
    "注意力得分的计算可以表示为矩阵乘法。对于查询 (query) 向量 $Q$ 和键 (key) 向量 $K$，计算注意力得分矩阵 $A$ 的公式为：\n",
    "\\begin{align*} A = \\frac{QK^T}{\\sqrt{d_k}} \\end{align*}\n",
    "其中 $d_k$ 是键向量的维度，这里的 $\\sqrt{d_k}$ 是一个缩放因子，防止得分值过大。\n",
    "\n",
    "在代码中，通过矩阵乘法实现：\n",
    "```python\n",
    "matmul_result = torch.baddbmm(\n",
    "    matmul_input_buffer,\n",
    "    query_layer,  # [b * np, sq, hn]\n",
    "    key_layer.transpose(1, 2),  # [b * np, hn, sk]\n",
    "    beta=0.0,\n",
    "    alpha=(1.0 / self.norm_factor),\n",
    ")\n",
    "```\n",
    "其中 `self.norm_factor` 是 $\\sqrt{d_k}$。\n",
    "\n",
    "#### 2. 应用注意力掩码 (Attention Mask)\n",
    "\n",
    "为了避免模型关注不必要的部分，使用注意力掩码来屏蔽某些位置。注意力掩码通过将相应位置的注意力得分设为负无穷大来实现：\n",
    "```python\n",
    "if attention_mask is not None:\n",
    "    attention_scores = attention_scores.masked_fill(attention_mask, float(\"-inf\"))\n",
    "```\n",
    "\n",
    "#### 3. 计算注意力权重 (Attention Weights)\n",
    "\n",
    "应用 softmax 函数将注意力得分转换为概率分布，表示每个查询向量对键向量的关注程度：\n",
    "\\begin{align*} \\text{AttentionWeights} = \\text{softmax}(A) \\end{align*}\n",
    "代码中实现为：\n",
    "```python\n",
    "attention_probs = F.softmax(attention_scores, dim=-1)\n",
    "attention_probs = attention_probs.type_as(value_layer)\n",
    "```\n",
    "\n",
    "#### 4. 计算上下文向量 (Context Vectors)\n",
    "\n",
    "上下文向量通过将注意力权重与值 (value) 向量相乘并求和得到：\n",
    "\\begin{align*} \\text{Context} = \\text{AttentionWeights} \\cdot V \\end{align*}\n",
    "其中 $V$ 是值向量。\n",
    "\n",
    "在代码中实现为：\n",
    "```python\n",
    "context_layer = torch.bmm(attention_probs, value_layer)\n",
    "```\n",
    "\n",
    "#### 详细原理解释\n",
    "\n",
    "1. **初始化和设置**：\n",
    "    - 初始化类时，会根据配置参数设置注意力缩放、层数等信息。\n",
    "    - `projection_size` 定义了投影大小，`hidden_size_per_attention_head` 和 `num_attention_heads_per_partition` 定义了每个注意力头和每个分区的隐藏层大小。\n",
    "\n",
    "2. **前向传播步骤**：\n",
    "    - **查询、键和值的计算**：计算查询、键和值向量。\n",
    "    - **计算注意力得分**：通过矩阵乘法计算注意力得分矩阵。\n",
    "    - **应用注意力掩码**：将需要屏蔽的位置设置为负无穷大，避免影响后续计算。\n",
    "    - **计算注意力权重**：应用 softmax 函数，得到注意力权重。\n",
    "    - **计算上下文向量**：将注意力权重与值向量相乘，得到上下文向量。\n",
    "\n",
    "3. **具体实现细节**：\n",
    "    - 根据 PyTorch 版本，选择合适的注意力计算方式。\n",
    "    - 对不同维度的张量进行变换，确保形状匹配。\n",
    "    - 使用 dropout 防止过拟合。\n",
    "\n",
    "### 公式和代码对应关系\n",
    "\n",
    "- **注意力得分**：\n",
    "  \\begin{align*}\n",
    "  A = \\frac{QK^T}{\\sqrt{d_k}}\n",
    "  \\end{align*}\n",
    "  对应代码：\n",
    "  ```python\n",
    "  matmul_result = torch.baddbmm(\n",
    "      matmul_input_buffer,\n",
    "      query_layer,\n",
    "      key_layer.transpose(1, 2),\n",
    "      beta=0.0,\n",
    "      alpha=(1.0 / self.norm_factor),\n",
    "  )\n",
    "  ```\n",
    "\n",
    "- **应用注意力掩码**：\n",
    "  \\begin{align*}\n",
    "  A'_{ij} = \\begin{cases} \n",
    "  A_{ij} & \\text{if } \\text{mask}_{ij} = 1 \\\\\n",
    "  -\\infty & \\text{if } \\text{mask}_{ij} = 0 \n",
    "  \\end{cases}\n",
    "  \\end{align*}\n",
    "  对应代码：\n",
    "  ```python\n",
    "  if attention_mask is not None:\n",
    "      attention_scores = attention_scores.masked_fill(attention_mask, float(\"-inf\"))\n",
    "  ```\n",
    "\n",
    "- **注意力权重**：\n",
    "  \\begin{align*}\n",
    "  \\text{AttentionWeights}_{ij} = \\frac{\\exp(A'_{ij})}{\\sum_k \\exp(A'_{ik})}\n",
    "  \\end{align*}\n",
    "  对应代码：\n",
    "  ```python\n",
    "  attention_probs = F.softmax(attention_scores, dim=-1)\n",
    "  ```\n",
    "\n",
    "- **上下文向量**：\n",
    "  \\begin{align*}\n",
    "  \\text{Context}_{i} = \\sum_j \\text{AttentionWeights}_{ij} V_j\n",
    "  \\end{align*}\n",
    "  对应代码：\n",
    "  ```python\n",
    "  context_layer = torch.bmm(attention_probs, value_layer)\n",
    "  ```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47fcf489-71b3-4077-af03-990979abc001",
   "metadata": {},
   "source": [
    "接下来，我们再重温一下注意力机制\n",
    "\n",
    "### 注意力机制原理及公式解释\n",
    "\n",
    "#### 1. 注意力机制是什么？\n",
    "注意力机制（Attention Mechanism）是深度学习中特别是自然语言处理领域中的一种技术，它使模型能够在处理输入序列时动态地关注不同的部分。简单来说，注意力机制让模型在处理某个元素时，可以有选择地关注输入序列的其他部分，而不是全部一视同仁。\n",
    "\n",
    "#### 2. 注意力机制的核心公式\n",
    "我们来看注意力得分的公式：\n",
    "\\begin{align*} A = \\frac{QK^T}{\\sqrt{d_k}} \\end{align*}\n",
    "\n",
    "这里的符号解释如下：\n",
    "- \\( Q \\) 是查询（Query）向量。\n",
    "- \\( K \\) 是键（Key）向量。\n",
    "- \\( d_k \\) 是键向量的维度。\n",
    "- \\( A \\) 是注意力得分矩阵。\n",
    "\n",
    "#### 3. 为什么使用这个公式计算注意力得分？\n",
    "\n",
    "注意力得分公式的核心思想是通过计算查询向量和键向量的点积，来衡量查询与每个键之间的相关性。点积结果越大，表示查询与该键的相关性越强，模型应该对该键对应的值（Value）向量给予更多的关注。\n",
    "\n",
    "公式中的缩放因子 \\(\\sqrt{d_k}\\) 是为了避免点积结果过大，因为如果不缩放，较大的向量维度会导致点积结果非常大，进而导致 softmax 函数输出接近于零的梯度，影响模型训练效果。\n",
    "\n",
    "#### 4. 计算注意力权重\n",
    "得到注意力得分矩阵 \\( A \\) 后，通过 softmax 函数将其转换为注意力权重矩阵：\n",
    "\\begin{align*} \\text{AttentionWeights}_{ij} = \\frac{\\exp(A'_{ij})}{\\sum_k \\exp(A'_{ik})} \\end{align*}\n",
    "其中 \\( A' \\) 是应用掩码后的注意力得分矩阵。\n",
    "\n",
    "#### 5. 计算上下文向量\n",
    "最后，使用注意力权重矩阵对值向量进行加权求和，得到上下文向量：\n",
    "\\begin{align*} \\text{Context}_{i} = \\sum_j \\text{AttentionWeights}_{ij} V_j \\end{align*}\n",
    "这里 \\( V \\) 是值向量。\n",
    "\n",
    "### 注意力机制与人类注意力的关系\n",
    "\n",
    "注意力机制与人类的注意力有一定的相似之处，但也有显著的区别。\n",
    "\n",
    "- **相似之处**：\n",
    "  - **选择性关注**：就像人类在阅读一篇文章时会选择性地关注某些重要段落，忽略其他部分，注意力机制也让模型在处理一个序列时可以选择性地关注不同的部分。\n",
    "  - **动态调整**：人类的注意力是动态的，会根据上下文调整关注点。注意力机制也是动态的，可以根据输入的变化调整注意力权重。\n",
    "\n",
    "- **区别**：\n",
    "  - **机制不同**：人类注意力是通过大脑的复杂神经网络实现的，包括视觉、听觉等多种感官信息的综合处理。而注意力机制是一种数学方法，通过点积、softmax 等操作实现。\n",
    "  - **目的不同**：人类注意力用于理解和互动，而注意力机制主要用于提高模型在处理长序列数据时的性能和效率。\n",
    "\n",
    "### 具体代码实现中的细节\n",
    "\n",
    "在 `CoreAttention` 类中，注意力得分的计算和应用通过以下代码片段实现：\n",
    "```python\n",
    "matmul_result = torch.baddbmm(\n",
    "    matmul_input_buffer,\n",
    "    query_layer,  # [b * np, sq, hn]\n",
    "    key_layer.transpose(1, 2),  # [b * np, hn, sk]\n",
    "    beta=0.0,\n",
    "    alpha=(1.0 / self.norm_factor),\n",
    ")\n",
    "```\n",
    "这里 `torch.baddbmm` 函数执行的是批量矩阵乘法，计算公式中的 \\(QK^T\\) 部分，并除以 \\(\\sqrt{d_k}\\) 进行缩放。\n",
    "\n",
    "应用 softmax 得到注意力权重：\n",
    "```python\n",
    "attention_probs = F.softmax(attention_scores, dim=-1)\n",
    "```\n",
    "\n",
    "最后，计算上下文向量：\n",
    "```python\n",
    "context_layer = torch.bmm(attention_probs, value_layer)\n",
    "```\n",
    "\n",
    "这一步将注意力权重与值向量相乘并求和，得到最终的上下文表示。\n",
    "\n",
    "通过上述解释，可以更好地理解注意力机制的原理和实现，以及它在模型中的重要作用。\n",
    "\n",
    "完成注意力机制的核心实现后，我们构建自注意力"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f0b003c-7e42-4e0a-96fe-a052826ea982",
   "metadata": {},
   "source": [
    "### SelfAttention 类的原理及公式解释\n",
    "\n",
    "#### SelfAttention 类的功能\n",
    "\n",
    "SelfAttention 类实现了自注意力机制（Self-Attention Mechanism），这是 Transformer 模型的核心部分。自注意力机制的目标是让每个位置的表示能够动态地关注输入序列的其他位置，从而捕捉全局信息。\n",
    "\n",
    "#### 自注意力机制的步骤及公式\n",
    "\n",
    "自注意力机制包括以下几个关键步骤：\n",
    "\n",
    "1. **线性变换**：\n",
    "   输入的隐藏状态 \\(X\\) 通过线性层分别映射到查询 \\(Q\\)、键 \\(K\\) 和值 \\(V\\) 三个空间：\n",
    "   \\begin{align*}\n",
    "   Q = XW_Q, \\quad K = XW_K, \\quad V = XW_V\n",
    "   \\end{align*}\n",
    "   其中，\\(W_Q\\)、\\(W_K\\) 和 \\(W_V\\) 是学习到的权重矩阵。\n",
    "\n",
    "2. **计算注意力得分**：\n",
    "   通过计算查询 \\(Q\\) 和键 \\(K\\) 的点积并除以缩放因子 \\(\\sqrt{d_k}\\) 得到注意力得分矩阵 \\(A\\)：\n",
    "   \\begin{align*}\n",
    "   A = \\frac{QK^T}{\\sqrt{d_k}}\n",
    "   \\end{align*}\n",
    "\n",
    "3. **应用注意力掩码**：\n",
    "   对于自注意力机制，如果使用掩码（例如在解码阶段），将不需要关注的位置设为负无穷大以屏蔽：\n",
    "   \\begin{align*}\n",
    "   A'_{ij} = \\begin{cases} \n",
    "   A_{ij} & \\text{if } \\text{mask}_{ij} = 1 \\\\\n",
    "   -\\infty & \\text{if } \\text{mask}_{ij} = 0 \n",
    "   \\end{cases}\n",
    "   \\end{align*}\n",
    "\n",
    "4. **计算注意力权重**：\n",
    "   对注意力得分矩阵 \\(A'\\) 应用 softmax 函数，得到注意力权重矩阵：\n",
    "   \\begin{align*}\n",
    "   \\text{AttentionWeights}_{ij} = \\frac{\\exp(A'_{ij})}{\\sum_k \\exp(A'_{ik})}\n",
    "   \\end{align*}\n",
    "\n",
    "5. **计算上下文向量**：\n",
    "   使用注意力权重对值 \\(V\\) 进行加权求和，得到上下文向量：\n",
    "   \\begin{align*}\n",
    "   \\text{Context} = \\text{AttentionWeights} \\cdot V\n",
    "   \\end{align*}\n",
    "\n",
    "### SelfAttention 类的具体实现\n",
    "\n",
    "以下是 SelfAttention 类中的核心步骤和公式的实现细节：\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "691ee97d-c7a7-4a41-846c-eb5f4a33443a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SelfAttention(torch.nn.Module):\n",
    "    \"\"\"并行自注意力层的抽象类。\n",
    "\n",
    "    自注意力层接受形状为 [s, b, h] 的输入，并返回相同形状的输出。\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config: ChatGLMConfig, layer_number, device=None):\n",
    "        super(SelfAttention, self).__init__()\n",
    "        self.layer_number = max(1, layer_number)\n",
    "\n",
    "        self.projection_size = config.kv_channels * config.num_attention_heads\n",
    "\n",
    "        # 每个注意力头和每个分区的值\n",
    "        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads\n",
    "        self.num_attention_heads_per_partition = config.num_attention_heads\n",
    "\n",
    "        self.multi_query_attention = config.multi_query_attention\n",
    "        self.qkv_hidden_size = 3 * self.projection_size\n",
    "        if self.multi_query_attention:\n",
    "            self.num_multi_query_groups_per_partition = config.multi_query_group_num\n",
    "            self.qkv_hidden_size = (\n",
    "                    self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num\n",
    "            )\n",
    "        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,\n",
    "                                         bias=config.add_bias_linear or config.add_qkv_bias,\n",
    "                                         device=device, **_config_to_kwargs(config)\n",
    "                                         )\n",
    "\n",
    "        self.core_attention = CoreAttention(config, self.layer_number)\n",
    "\n",
    "        # 输出层\n",
    "        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,\n",
    "                               device=device, **_config_to_kwargs(config)\n",
    "                               )\n",
    "\n",
    "    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):\n",
    "        # hidden_states: [b, sq, h]\n",
    "\n",
    "        # =====================\n",
    "        # Query, Key 和 Value\n",
    "        # =====================\n",
    "\n",
    "        # 注意力头 [b, sq, h] --> [b, sq, (np * 3 * hn)]\n",
    "        mixed_x_layer = self.query_key_value(hidden_states)\n",
    "\n",
    "        if self.multi_query_attention:\n",
    "            (query_layer, key_layer, value_layer) = mixed_x_layer.split(\n",
    "                [\n",
    "                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,\n",
    "                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,\n",
    "                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,\n",
    "                ],\n",
    "                dim=-1,\n",
    "            )\n",
    "            query_layer = query_layer.view(\n",
    "                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)\n",
    "            )\n",
    "            key_layer = key_layer.view(\n",
    "                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)\n",
    "            )\n",
    "            value_layer = value_layer.view(\n",
    "                value_layer.size()[:-1]\n",
    "                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)\n",
    "            )\n",
    "        else:\n",
    "            new_tensor_shape = mixed_x_layer.size()[:-1] + \\\n",
    "                               (self.num_attention_heads_per_partition,\n",
    "                                3 * self.hidden_size_per_attention_head)\n",
    "            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)\n",
    "\n",
    "            # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]\n",
    "            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n",
    "\n",
    "        # [b, sq, np, hn] -> [b, np, sq, hn]\n",
    "        query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]\n",
    "\n",
    "        # 应用相对位置编码（旋转嵌入）\n",
    "        if rotary_pos_emb is not None:\n",
    "            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)\n",
    "            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)\n",
    "\n",
    "        # 调整 key 和 value 以用于推理\n",
    "        if kv_cache is not None:\n",
    "            cache_k, cache_v = kv_cache\n",
    "            key_layer = torch.cat((cache_k, key_layer), dim=2)\n",
    "            value_layer = torch.cat((cache_v, value_layer), dim=2)\n",
    "        if use_cache:\n",
    "            if kv_cache is None:\n",
    "                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)\n",
    "            else:\n",
    "                kv_cache = (key_layer, value_layer)\n",
    "        else:\n",
    "            kv_cache = None\n",
    "\n",
    "        if self.multi_query_attention:\n",
    "            key_layer = key_layer.unsqueeze(2)\n",
    "            key_layer = key_layer.expand(\n",
    "                -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1\n",
    "            )\n",
    "            key_layer = key_layer.contiguous().view(\n",
    "                key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:]\n",
    "            )\n",
    "            value_layer = value_layer.unsqueeze(2)\n",
    "            value_layer = value_layer.expand(\n",
    "                -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1\n",
    "            )\n",
    "            value_layer = value_layer.contiguous().view(\n",
    "                value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:]\n",
    "            )\n",
    "\n",
    "        # 核心注意力计算\n",
    "        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)\n",
    "\n",
    "        # 输出. [sq, b, h]\n",
    "        output = self.dense(context_layer)\n",
    "\n",
    "        return output, kv_cache"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "345e5da0-10a8-4c65-83ce-072f59c3b376",
   "metadata": {},
   "source": [
    "### 详细原理解释\n",
    "\n",
    "1. **线性变换**：\n",
    "   ```python\n",
    "   mixed_x_layer = self.query_key_value(hidden_states)\n",
    "   ```\n",
    "   输入隐藏状态 `hidden_states` 通过线性层映射到查询、键和值的空间。\n",
    "\n",
    "2. **拆分查询、键和值**：\n",
    "   ```python\n",
    "   (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n",
    "   ```\n",
    "   将线性变换后的结果拆分成查询、键和值。\n",
    "\n",
    "3. **形状变换**：\n",
    "   ```python\n",
    "   query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]\n",
    "   ```\n",
    "   将查询、键和值的形状从 `[b, sq, np, hn]` 转换为 `[b, np, sq, hn]`。\n",
    "\n",
    "4. **应用旋转位置编码**：\n",
    "   ```python\n",
    "   if rotary_pos_emb is not None:\n",
    "       query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)\n",
    "       key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)\n",
    "   ```\n",
    "   如果存在旋转位置编码，则对查询和键应用位置编码。\n",
    "\n",
    "5. **调整键和值用于推理**：\n",
    "   ```python\n",
    "   if kv_cache is not None:\n",
    "       cache_k, cache_v = kv_cache\n",
    "      \n",
    "\n",
    " key_layer = torch.cat((cache_k, key_layer), dim=2)\n",
    "       value_layer = torch.cat((cache_v, value_layer), dim=2)\n",
    "   if use_cache:\n",
    "       if kv_cache is None:\n",
    "           kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)\n",
    "       else:\n",
    "           kv_cache = (key_layer, value_layer)\n",
    "   else:\n",
    "       kv_cache = None\n",
    "   ```\n",
    "\n",
    "6. **核心注意力计算**：\n",
    "   ```python\n",
    "   context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)\n",
    "   ```\n",
    "\n",
    "7. **输出层**：\n",
    "   ```python\n",
    "   output = self.dense(context_layer)\n",
    "   ```\n",
    "\n",
    "通过上述详细的原理解释和公式，可以更好地理解 SelfAttention 类的实现以及其在 Transformer 模型中的作用。自注意力机制通过动态调整不同位置之间的权重，使得模型能够更有效地捕捉全局信息，从而提高模型的性能和泛化能力。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "00c11e28-47e7-437b-aab8-1da8b736443a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _config_to_kwargs(args):\n",
    "    common_kwargs = {\n",
    "        \"dtype\": args.torch_dtype,\n",
    "    }\n",
    "    return common_kwargs\n",
    "\n",
    "\n",
    "class MLP(torch.nn.Module):\n",
    "    \"\"\"多层感知机（MLP）。\n",
    "\n",
    "    MLP 将接受隐藏状态为 h 的输入，将其投影到 4*h 的隐藏维度，执行非线性变换，然后将状态投影回 h 的隐藏维度。\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config: ChatGLMConfig, device=None):\n",
    "        super(MLP, self).__init__()\n",
    "\n",
    "        self.add_bias = config.add_bias_linear\n",
    "\n",
    "        # 投影到 4h。如果使用 swiglu 则将输出宽度加倍，详见 https://arxiv.org/pdf/2002.05202.pdf\n",
    "        self.dense_h_to_4h = nn.Linear(\n",
    "            config.hidden_size,\n",
    "            config.ffn_hidden_size * 2,\n",
    "            bias=self.add_bias,\n",
    "            device=device,\n",
    "            **_config_to_kwargs(config)\n",
    "        )\n",
    "\n",
    "        def swiglu(x):\n",
    "            x = torch.chunk(x, 2, dim=-1)\n",
    "            return F.silu(x[0]) * x[1]\n",
    "\n",
    "        self.activation_func = swiglu\n",
    "\n",
    "        # 投影回 h.\n",
    "        self.dense_4h_to_h = nn.Linear(\n",
    "            config.ffn_hidden_size,\n",
    "            config.hidden_size,\n",
    "            bias=self.add_bias,\n",
    "            device=device,\n",
    "            **_config_to_kwargs(config)\n",
    "        )\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        # [s, b, 4hp]\n",
    "        intermediate_parallel = self.dense_h_to_4h(hidden_states)\n",
    "        intermediate_parallel = self.activation_func(intermediate_parallel)\n",
    "        # [s, b, h]\n",
    "        output = self.dense_4h_to_h(intermediate_parallel)\n",
    "        return output\n",
    "\n",
    "\n",
    "class GLMBlock(torch.nn.Module):\n",
    "    \"\"\"一个 transformer 层。\n",
    "\n",
    "    Transformer 层接受形状为 [s, b, h] 的输入，并返回相同形状的输出。\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config: ChatGLMConfig, layer_number, device=None):\n",
    "        super(GLMBlock, self).__init__()\n",
    "        self.layer_number = layer_number\n",
    "\n",
    "        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm\n",
    "        self.fp32_residual_connection = config.fp32_residual_connection\n",
    "\n",
    "        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm\n",
    "        # 输入数据上的层归一化\n",
    "        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,\n",
    "                                             dtype=config.torch_dtype)\n",
    "\n",
    "        # 自注意力层\n",
    "        self.self_attention = SelfAttention(config, layer_number, device=device)\n",
    "        self.hidden_dropout = config.hidden_dropout\n",
    "\n",
    "        # 注意力输出上的层归一化\n",
    "        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,\n",
    "                                                      dtype=config.torch_dtype)\n",
    "\n",
    "        # 多层感知机（MLP）\n",
    "        self.mlp = MLP(config, device=device)\n",
    "\n",
    "    def forward(\n",
    "            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,\n",
    "    ):\n",
    "        # hidden_states: [s, b, h]\n",
    "\n",
    "        # 在 transformer 层开始时的层归一化\n",
    "        layernorm_output = self.input_layernorm(hidden_states)\n",
    "        # 自注意力层\n",
    "        attention_output, kv_cache = self.self_attention(\n",
    "            layernorm_output,\n",
    "            attention_mask,\n",
    "            rotary_pos_emb,\n",
    "            kv_cache=kv_cache,\n",
    "            use_cache=use_cache\n",
    "        )\n",
    "\n",
    "        # 残差连接\n",
    "        if self.apply_residual_connection_post_layernorm:\n",
    "            residual = layernorm_output\n",
    "        else:\n",
    "            residual = hidden_states\n",
    "\n",
    "        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)\n",
    "        layernorm_input = residual + layernorm_input\n",
    "\n",
    "        # 自注意力后的层归一化\n",
    "        layernorm_output = self.post_attention_layernorm(layernorm_input)\n",
    "\n",
    "        # 多层感知机（MLP）\n",
    "        mlp_output = self.mlp(layernorm_output)\n",
    "\n",
    "        # 第二个残差连接\n",
    "        if self.apply_residual_connection_post_layernorm:\n",
    "            residual = layernorm_output\n",
    "        else:\n",
    "            residual = layernorm_input\n",
    "\n",
    "        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)\n",
    "        output = residual + output\n",
    "\n",
    "        return output, kv_cache\n",
    "\n",
    "\n",
    "class GLMTransformer(torch.nn.Module):\n",
    "    \"\"\"Transformer 类。\"\"\"\n",
    "\n",
    "    def __init__(self, config: ChatGLMConfig, device=None):\n",
    "        super(GLMTransformer, self).__init__()\n",
    "\n",
    "        self.fp32_residual_connection = config.fp32_residual_connection\n",
    "        self.post_layer_norm = config.post_layer_norm\n",
    "\n",
    "        # 层数\n",
    "        self.num_layers = config.num_layers\n",
    "\n",
    "        # Transformer 层\n",
    "        def build_layer(layer_number):\n",
    "            return GLMBlock(config, layer_number, device=device)\n",
    "\n",
    "        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])\n",
    "\n",
    "        if self.post_layer_norm:\n",
    "            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm\n",
    "            # 输出前的最终层归一化\n",
    "            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,\n",
    "                                                 dtype=config.torch_dtype)\n",
    "\n",
    "        self.gradient_checkpointing = False\n",
    "\n",
    "    def _get_layer(self, layer_number):\n",
    "        return self.layers[layer_number]\n",
    "\n",
    "    def forward(\n",
    "            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,\n",
    "            use_cache: Optional[bool] = True,\n",
    "            output_hidden_states: Optional[bool] = False,\n",
    "    ):\n",
    "        if not kv_caches:\n",
    "            kv_caches = [None for _ in range(self.num_layers)]\n",
    "        presents = () if use_cache else None\n",
    "        if self.gradient_checkpointing and self.training:\n",
    "            if use_cache:\n",
    "                logger.warning_once(\n",
    "                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n",
    "                )\n",
    "                use_cache = False\n",
    "\n",
    "        all_self_attentions = None\n",
    "        all_hidden_states = () if output_hidden_states else None\n",
    "        for index in range(self.num_layers):\n",
    "            if output_hidden_states:\n",
    "                all_hidden_states = all_hidden_states + (hidden_states,)\n",
    "\n",
    "            layer = self._get_layer(index)\n",
    "            if self.gradient_checkpointing and self.training:\n",
    "                layer_ret = torch.utils.checkpoint.checkpoint(\n",
    "                    layer,\n",
    "                    hidden_states,\n",
    "                    attention_mask,\n",
    "                    rotary_pos_emb,\n",
    "                    kv_caches[index],\n",
    "                    use_cache,\n",
    "                    use_reentrant=False\n",
    "                )\n",
    "            else:\n",
    "                layer_ret = layer(\n",
    "                    hidden_states,\n",
    "                    attention_mask,\n",
    "                    rotary_pos_emb,\n",
    "                    kv_cache=kv_caches[index],\n",
    "                    use_cache=use_cache\n",
    "                )\n",
    "            hidden_states, kv_cache = layer_ret\n",
    "            if use_cache:\n",
    "                # token by token 解码，使用元组格式\n",
    "                if kv_caches[0] is not None:\n",
    "                    presents = presents + (kv_cache,)\n",
    "                # 预填充解码，使用张量格式以节省 CUDA 内存\n",
    "                else:\n",
    "                    if len(presents) == 0:\n",
    "                        presents = kv_cache\n",
    "                    else:\n",
    "                        presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)\n",
    "\n",
    "        if output_hidden_states:\n",
    "            all_hidden_states = all_hidden_states + (hidden_states,)\n",
    "\n",
    "        # 最终层归一化\n",
    "        if self.post_layer_norm:\n",
    "            hidden_states = self.final_layernorm(hidden_states)\n",
    "\n",
    "        return hidden_states, presents, all_hidden_states, all_self_attentions\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ad537c5-ceb0-41a7-b889-fc0fe633bab6",
   "metadata": {},
   "source": [
    "## 张量形状变化及其意义。\n",
    "\n",
    "### `CoreAttention` 类中的张量形状\n",
    "\n",
    "1. **`query_layer`, `key_layer`, `value_layer` 形状变化**:\n",
    "   ```python\n",
    "   # 输入形状 [b, np, sq, hn] 和 [b, np, sk, hn]\n",
    "   # [b, np, sq, hn] -> [b * np, sq, hn]\n",
    "   query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)\n",
    "   # [b, np, sk, hn] -> [b * np, sk, hn]\n",
    "   key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)\n",
    "   ```\n",
    "\n",
    "2. **原始注意力得分的计算**:\n",
    "   ```python\n",
    "   # 输入形状 [b * np, sq, hn] 和 [b * np, hn, sk]\n",
    "   # 结果形状 [b * np, sq, sk]\n",
    "   matmul_result = torch.baddbmm(matmul_input_buffer, query_layer, key_layer.transpose(1, 2), beta=0.0, alpha=(1.0 / self.norm_factor))\n",
    "   ```\n",
    "\n",
    "3. **调整视图到原始形状**:\n",
    "   ```python\n",
    "   # [b * np, sq, sk] -> [b, np, sq, sk]\n",
    "   attention_scores = matmul_result.view(*output_size)\n",
    "   ```\n",
    "\n",
    "4. **注意力概率形状变化**:\n",
    "   ```python\n",
    "   # 输入形状 [b, np, sq, sk]\n",
    "   # 调整后的形状 [b * np, sq, sk]\n",
    "   attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)\n",
    "   ```\n",
    "\n",
    "5. **计算上下文层**:\n",
    "   ```python\n",
    "   # 输入形状 [b * np, sq, sk] 和 [b * np, sk, hn]\n",
    "   # 结果形状 [b * np, sq, hn]\n",
    "   context_layer = torch.bmm(attention_probs, value_layer)\n",
    "   # 调整视图 [b * np, sq, hn] -> [b, np, sq, hn]\n",
    "   context_layer = context_layer.view(*output_size)\n",
    "   # [b, np, sq, hn] -> [b, sq, np, hn]\n",
    "   context_layer = context_layer.transpose(1, 2).contiguous()\n",
    "   # [b, sq, np, hn] -> [b, sq, hp]\n",
    "   new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)\n",
    "   context_layer = context_layer.reshape(*new_context_layer_shape)\n",
    "   ```\n",
    "\n",
    "### `SelfAttention` 类中的张量形状\n",
    "\n",
    "1. **`mixed_x_layer` 形状变化**:\n",
    "   ```python\n",
    "   # 输入形状 [b, sq, h]\n",
    "   # 结果形状 [b, sq, (np * 3 * hn)]\n",
    "   mixed_x_layer = self.query_key_value(hidden_states)\n",
    "   ```\n",
    "\n",
    "2. **多查询注意力的形状变化**:\n",
    "   ```python\n",
    "   # [b, sq, (np * 3 * hn)] -> [b, sq, np, hn] 和 [b, sq, np, hn]\n",
    "   (query_layer, key_layer, value_layer) = mixed_x_layer.split([...], dim=-1)\n",
    "   # 调整视图\n",
    "   query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))\n",
    "   key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head))\n",
    "   value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head))\n",
    "   ```\n",
    "\n",
    "3. **普通注意力的形状变化**:\n",
    "   ```python\n",
    "   # [b, sq, (np * 3 * hn)] -> [b, sq, np, 3 * hn]\n",
    "   new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head)\n",
    "   mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)\n",
    "   # [b, sq, np, 3 * hn] -> 3 [b, sq, np, hn]\n",
    "   (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n",
    "   ```\n",
    "\n",
    "4. **转置操作**:\n",
    "   ```python\n",
    "   # [b, sq, np, hn] -> [b, np, sq, hn]\n",
    "   query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]\n",
    "   ```\n",
    "\n",
    "### `MLP` 类中的张量形状\n",
    "\n",
    "1. **前向传播**:\n",
    "   ```python\n",
    "   # 输入形状 [s, b, h]\n",
    "   # 经过 dense_h_to_4h 线性层后形状 [s, b, 4hp]\n",
    "   intermediate_parallel = self.dense_h_to_4h(hidden_states)\n",
    "   intermediate_parallel = self.activation_func(intermediate_parallel)\n",
    "   # 经过 dense_4h_to_h 线性层后形状 [s, b, h]\n",
    "   output = self.dense_4h_to_h(intermediate_parallel)\n",
    "   ```\n",
    "\n",
    "### `GLMBlock` 类中的张量形状\n",
    "\n",
    "1. **自注意力层的输入输出形状**:\n",
    "   ```python\n",
    "   # 输入形状 [s, b, h]\n",
    "   # 自注意力层输出形状 [s, b, h]\n",
    "   attention_output, kv_cache = self.self_attention(layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache)\n",
    "   ```\n",
    "\n",
    "2. **残差连接**:\n",
    "   ```python\n",
    "   # 残差连接，形状保持不变 [s, b, h]\n",
    "   layernorm_input = residual + layernorm_input\n",
    "   ```\n",
    "\n",
    "3. **多层感知机（MLP）的输入输出形状**:\n",
    "   ```python\n",
    "   # MLP 输出形状 [s, b, h]\n",
    "   mlp_output = self.mlp(layernorm_output)\n",
    "   ```\n",
    "\n",
    "### `GLMTransformer` 类中的张量形状\n",
    "\n",
    "1. **逐层处理**:\n",
    "   ```python\n",
    "   # 每层的输入输出形状 [s, b, h]\n",
    "   for index in range(self.num_layers):\n",
    "       layer = self._get_layer(index)\n",
    "       hidden_states, kv_cache = layer(hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache)\n",
    "   ```\n",
    "\n",
    "2. **最终层归一化**:\n",
    "   ```python\n",
    "   # 最终层归一化，形状 [s, b, h]\n",
    "   if self.post_layer_norm:\n",
    "       hidden_states = self.final_layernorm(hidden_states)\n",
    "   ```\n",
    "\n",
    "总结这些形状变化，有助于理解每个层的输入输出如何传递和处理，确保模型在每个步骤中保持正确的张量形状。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b59c956-a31d-459c-a22d-ef3920644ab5",
   "metadata": {},
   "source": [
    "### 关键点解释\n",
    "\n",
    "1. **`CoreAttention` 类**：\n",
    "   - 实现了核心注意力机制。\n",
    "   - 支持 PyTorch 2.0 及以上版本的 `scaled_dot_product_attention`，以及旧版本的自定义注意力计算。\n",
    "\n",
    "2. **`SelfAttention` 类**：\n",
    "   - 实现了自注意力层。\n",
    "   - 包括 query、key、value 的计算以及核心注意力机制的应用。\n",
    "   - 支持多查询注意力机制。\n",
    "\n",
    "3. **`MLP` 类**：\n",
    "   - 多层感知机（MLP），包括两个线性层和一个非线性激活函数。\n",
    "\n",
    "4. **`GLMBlock` 类**：\n",
    "   - 实现了一个 Transformer 层，包括自注意力层和 MLP 层。\n",
    "   - 包括层归一化和残差连接。\n",
    "\n",
    "5. **`GLMTransformer` 类**：\n",
    "   - 实现了 Transformer 模型，包括多个 Transformer 层。\n",
    "   - 支持梯度检查点和最终层归一化。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48414b7e-3bb1-4a06-b411-5492035dabcf",
   "metadata": {},
   "source": [
    "### 自注意力机制提取输入隐藏状态的步骤\n",
    "\n",
    "在 SelfAttention 类中，自注意力机制通过一系列的步骤来提取和处理输入的隐藏状态（hidden states），最终生成上下文向量。这些步骤包括线性变换、计算注意力得分、应用注意力掩码、计算注意力权重和生成上下文向量。以下是详细的步骤和解释：\n",
    "\n",
    "#### 1. 输入隐藏状态\n",
    "\n",
    "输入隐藏状态 \\( \\text{hidden_states} \\) 的形状通常为 \\([b, sq, h]\\)，其中：\n",
    "- \\( b \\) 是批次大小（batch size）。\n",
    "- \\( sq \\) 是序列长度（sequence length）。\n",
    "- \\( h \\) 是隐藏层维度（hidden size）。\n",
    "\n",
    "#### 2. 线性变换\n",
    "\n",
    "将输入隐藏状态 \\( \\text{hidden_states} \\) 通过线性层投影到查询 \\( Q \\)、键 \\( K \\) 和值 \\( V \\) 空间。这一步的目的是将输入映射到不同的子空间，以便进行注意力计算。公式如下：\n",
    "\\begin{align*} Q = \\text{hidden_states} \\cdot W_Q \\end{align*}\n",
    "\\begin{align*} K = \\text{hidden_states} \\cdot W_K \\end{align*}\n",
    "\\begin{align*} V = \\text{hidden_states} \\cdot W_V \\end{align*}\n",
    "\n",
    "在代码中，通过以下方式实现：\n",
    "```python\n",
    "mixed_x_layer = self.query_key_value(hidden_states)\n",
    "```\n",
    "这里 `self.query_key_value` 是一个线性层，它将输入隐藏状态投影到一个更大的空间，结果的形状为 \\([b, sq, (3 \\times \\text{hidden_size})]\\)。\n",
    "\n",
    "#### 3. 拆分查询、键和值\n",
    "\n",
    "将上述结果拆分成查询、键和值向量。拆分后，每个向量的形状为 \\([b, sq, \\text{hidden_size}]\\)。\n",
    "```python\n",
    "(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)\n",
    "```\n",
    "`split_tensor_along_last_dim` 函数按最后一个维度将张量分割成三部分。\n",
    "\n",
    "#### 4. 形状变换\n",
    "\n",
    "为了适应后续的矩阵乘法操作，需要调整查询、键和值的形状。将它们从 \\([b, sq, np, hn]\\) 转换为 \\([b, np, sq, hn]\\)，其中 \\( np \\) 是注意力头的数量，\\( hn \\) 是每个注意力头的维度。\n",
    "```python\n",
    "query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]\n",
    "```\n",
    "\n",
    "#### 5. 应用旋转位置编码\n",
    "\n",
    "如果存在旋转位置编码（rotary position embedding），则应用到查询和键向量上。旋转位置编码可以帮助模型更好地捕捉位置信息。\n",
    "```python\n",
    "if rotary_pos_emb is not None:\n",
    "    query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)\n",
    "    key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)\n",
    "```\n",
    "\n",
    "#### 6. 调整键和值用于推理\n",
    "\n",
    "在推理阶段，可能需要缓存键和值，以便在下一个时间步中重复使用，从而提高效率。将缓存的键和值与当前时间步的键和值进行拼接。\n",
    "```python\n",
    "if kv_cache is not None:\n",
    "    cache_k, cache_v = kv_cache\n",
    "    key_layer = torch.cat((cache_k, key_layer), dim=2)\n",
    "    value_layer = torch.cat((cache_v, value_layer), dim=2)\n",
    "if use_cache:\n",
    "    if kv_cache is None:\n",
    "        kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)\n",
    "    else:\n",
    "        kv_cache = (key_layer, value_layer)\n",
    "else:\n",
    "    kv_cache = None\n",
    "```\n",
    "\n",
    "#### 7. 核心注意力计算\n",
    "\n",
    "使用 CoreAttention 类来计算注意力得分、注意力权重和上下文向量。\n",
    "```python\n",
    "context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)\n",
    "```\n",
    "\n",
    "#### 8. 输出层\n",
    "\n",
    "最后，将上下文向量通过线性层投影回原始的隐藏层维度。这一步是将计算后的结果映射回输入的形状，以便进行后续处理。\n",
    "```python\n",
    "output = self.dense(context_layer)\n",
    "```\n",
    "\n",
    "### 总结\n",
    "\n",
    "通过上述步骤，自注意力机制实现了从输入隐藏状态中提取和处理信息的过程。每一步的详细解释如下：\n",
    "\n",
    "1. **输入隐藏状态**：输入序列的隐藏表示。\n",
    "2. **线性变换**：将隐藏表示投影到查询、键和值的空间。\n",
    "3. **拆分查询、键和值**：将投影后的结果拆分成查询、键和值向量。\n",
    "4. **形状变换**：调整查询、键和值的形状，以适应后续操作。\n",
    "5. **应用旋转位置编码**：增强查询和键向量的位置信息。\n",
    "6. **调整键和值用于推理**：在推理阶段缓存键和值，以提高效率。\n",
    "7. **核心注意力计算**：通过计算注意力得分、权重和上下文向量，完成注意力机制的核心部分。\n",
    "8. **输出层**：将上下文向量投影回原始的隐藏层维度。\n",
    "\n",
    "这些步骤共同构成了自注意力机制，从而使模型能够动态地关注输入序列中的不同部分，捕捉全局信息并生成更丰富的表示。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d16f0c2-db97-4512-9a09-1e59502d6e19",
   "metadata": {},
   "source": [
    "### ChatGLMPreTrainedModel 类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "9d1e5174-cc11-4ba0-ad18-b58110bf084b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChatGLMPreTrainedModel(PreTrainedModel):\n",
    "    \"\"\"\n",
    "    处理权重初始化和下载及加载预训练模型的简单接口的抽象类。\n",
    "    \"\"\"\n",
    "    is_parallelizable = False  # 是否可并行化\n",
    "    supports_gradient_checkpointing = True  # 是否支持梯度检查点\n",
    "    config_class = ChatGLMConfig  # 配置类\n",
    "    base_model_prefix = \"transformer\"  # 基础模型前缀\n",
    "    _no_split_modules = [\"GLMBlock\"]  # 不拆分的模块列表\n",
    "\n",
    "    def _init_weights(self, module: nn.Module):\n",
    "        \"\"\"初始化权重\"\"\"\n",
    "        return\n",
    "\n",
    "    def get_masks(self, input_ids, past_key_values, padding_mask=None):\n",
    "        \"\"\"\n",
    "        获取注意力掩码\n",
    "        \"\"\"\n",
    "        batch_size, seq_length = input_ids.shape\n",
    "        # 创建下三角矩阵作为全注意力掩码\n",
    "        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)\n",
    "        full_attention_mask.tril_()\n",
    "        past_length = 0\n",
    "        if past_key_values:\n",
    "            past_length = past_key_values[0][0].shape[2]\n",
    "        if past_length:\n",
    "            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1)\n",
    "        if padding_mask is not None:\n",
    "            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)\n",
    "        if not past_length and padding_mask is not None:\n",
    "            full_attention_mask -= padding_mask.unsqueeze(-1) - 1\n",
    "        full_attention_mask = (full_attention_mask < 0.5).bool()\n",
    "        full_attention_mask.unsqueeze_(1)\n",
    "        return full_attention_mask\n",
    "\n",
    "    def get_position_ids(self, input_ids, device):\n",
    "        \"\"\"\n",
    "        获取位置ID\n",
    "        \"\"\"\n",
    "        batch_size, seq_length = input_ids.shape\n",
    "        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)\n",
    "        return position_ids\n",
    "\n",
    "    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):\n",
    "        if not self.supports_gradient_checkpointing:\n",
    "            raise ValueError(f\"{self.__class__.__name__} does not support gradient checkpointing.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16d19827-5201-4981-9618-520275a14b6c",
   "metadata": {},
   "source": [
    "### Embedding 类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "ede82883-45bb-41f8-b679-a97b2d04c87e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Embedding(torch.nn.Module):\n",
    "    \"\"\"语言模型嵌入层。\"\"\"\n",
    "\n",
    "    def __init__(self, config: ChatGLMConfig, device=None):\n",
    "        super(Embedding, self).__init__()\n",
    "\n",
    "        self.hidden_size = config.hidden_size\n",
    "        # 词嵌入层（并行）\n",
    "        self.word_embeddings = nn.Embedding(\n",
    "            config.padded_vocab_size,\n",
    "            self.hidden_size,\n",
    "            dtype=config.torch_dtype,\n",
    "            device=device\n",
    "        )\n",
    "        self.fp32_residual_connection = config.fp32_residual_connection\n",
    "\n",
    "    def forward(self, input_ids):\n",
    "        # 获取词嵌入\n",
    "        words_embeddings = self.word_embeddings(input_ids)\n",
    "        embeddings = words_embeddings\n",
    "        # 如果设置了fp32残差连接，则转换为浮点数\n",
    "        if self.fp32_residual_connection:\n",
    "            embeddings = embeddings.float()\n",
    "        return embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20e7568a-a42f-4053-9939-152a656181e4",
   "metadata": {},
   "source": [
    "### ChatGLMModel 类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "21ed53de-5687-4093-9662-efbc224093b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChatGLMModel(ChatGLMPreTrainedModel):\n",
    "    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):\n",
    "        super().__init__(config)\n",
    "        if empty_init:\n",
    "            init_method = skip_init  # 跳过初始化\n",
    "        else:\n",
    "            init_method = default_init  # 使用默认初始化方法\n",
    "        init_kwargs = {}\n",
    "        if device is not None:\n",
    "            init_kwargs[\"device\"] = device\n",
    "        self.embedding = init_method(Embedding, config, **init_kwargs)  # 使用 Embedding 类\n",
    "        self.num_layers = config.num_layers\n",
    "        self.multi_query_group_num = config.multi_query_group_num\n",
    "        self.kv_channels = config.kv_channels\n",
    "\n",
    "        # 旋转位置嵌入\n",
    "        self.seq_length = config.seq_length\n",
    "        rotary_dim = (\n",
    "            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels\n",
    "        )\n",
    "\n",
    "        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope, \n",
    "                                              device=device, dtype=config.torch_dtype)  # 使用 RotaryEmbedding 类\n",
    "        self.encoder = init_method(GLMTransformer, config, **init_kwargs)  # 使用 GLMTransformer 类\n",
    "        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,\n",
    "                                        dtype=config.torch_dtype, **init_kwargs)  # 使用 nn.Linear 类\n",
    "\n",
    "    def get_input_embeddings(self):\n",
    "        return self.embedding.word_embeddings\n",
    "\n",
    "    def set_input_embeddings(self, value):\n",
    "        self.embedding.word_embeddings = value\n",
    "\n",
    "    def forward(\n",
    "            self,\n",
    "            input_ids,\n",
    "            position_ids: Optional[torch.Tensor] = None,\n",
    "            attention_mask: Optional[torch.BoolTensor] = None,\n",
    "            full_attention_mask: Optional[torch.BoolTensor] = None,\n",
    "            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n",
    "            inputs_embeds: Optional[torch.Tensor] = None,\n",
    "            use_cache: Optional[bool] = None,\n",
    "            output_hidden_states: Optional[bool] = None,\n",
    "            return_dict: Optional[bool] = None,\n",
    "    ):\n",
    "        output_hidden_states = (\n",
    "            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n",
    "        )\n",
    "        use_cache = use_cache if use_cache is not None else self.config.use_cache\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        batch_size, seq_length = input_ids.shape\n",
    "\n",
    "        if inputs_embeds is None:\n",
    "            inputs_embeds = self.embedding(input_ids)  # 使用 Embedding 类\n",
    "\n",
    "        if full_attention_mask is None:\n",
    "            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):\n",
    "                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)\n",
    "\n",
    "        # 旋转位置嵌入\n",
    "        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)  # 使用 RotaryEmbedding 类\n",
    "        if position_ids is not None:\n",
    "            rotary_pos_emb = rotary_pos_emb[position_ids]\n",
    "        else:\n",
    "            rotary_pos_emb = rotary_pos_emb[None, :seq_length]\n",
    "\n",
    "        # 运行编码器\n",
    "        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(\n",
    "            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,\n",
    "            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states\n",
    "        )  # 使用 GLMTransformer 类\n",
    "        if presents is not None and type(presents) is torch.Tensor:\n",
    "            presents = presents.split(1, dim=0)\n",
    "            presents = list(presents)\n",
    "            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]\n",
    "            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]\n",
    "            presents = tuple(presents)\n",
    "\n",
    "        if not return_dict:\n",
    "            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)\n",
    "\n",
    "        return BaseModelOutputWithPast(\n",
    "            last_hidden_state=hidden_states,\n",
    "            past_key_values=presents,\n",
    "            hidden_states=all_hidden_states,\n",
    "            attentions=all_self_attentions,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48b86c96-a071-49aa-8676-f1d8be23a31f",
   "metadata": {},
   "source": [
    "### 总结\n",
    "\n",
    "1. **ChatGLMPreTrainedModel**:\n",
    "   - **get_masks**：生成注意力掩码。\n",
    "   - **get_position_ids**：生成位置 ID。\n",
    "\n",
    "2. **Embedding**:\n",
    "   - **word_embeddings**：实现词嵌入。\n",
    "\n",
    "3. **ChatGLMModel**:\n",
    "   - **RotaryEmbedding**：用于位置编码。\n",
    "   - **GLMTransformer**：实现 Transformer 编码器。\n",
    "   - **forward**：执行前向传播，集成所有模块。\n",
    "\n",
    "4. **使用的先前构建的模块**：\n",
    "   - **RotaryEmbedding**：用于生成旋转位置嵌入。\n",
    "   - **GLMTransformer**：用于编码器部分。\n",
    "   - **Embedding**：用于生成词嵌入。\n",
    "   - **CoreAttention**、**SelfAttention**：间接通过 GLMTransformer 使用。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4af49b32-bd82-4b62-9d95-494ae2d78dfb",
   "metadata": {},
   "source": [
    "### ChatGLMForConditionalGeneration 类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "4725aadc-c0e6-4d42-bebb-a4194dd695f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):\n",
    "    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):\n",
    "        super().__init__(config)\n",
    "\n",
    "        self.max_sequence_length = config.max_length  # 最大序列长度\n",
    "        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)  # 使用 ChatGLMModel 类\n",
    "        self.config = config\n",
    "\n",
    "    def _update_model_kwargs_for_generation(\n",
    "            self,\n",
    "            outputs: ModelOutput,\n",
    "            model_kwargs: Dict[str, Any],\n",
    "            is_encoder_decoder: bool = False,\n",
    "            standardize_cache_format: bool = False,\n",
    "    ) -> Dict[str, Any]:\n",
    "        # 更新 past_key_values\n",
    "        model_kwargs[\"past_key_values\"] = self._extract_past_from_model_output(\n",
    "            outputs, standardize_cache_format=standardize_cache_format\n",
    "        )\n",
    "\n",
    "        # 更新注意力掩码\n",
    "        if \"attention_mask\" in model_kwargs:\n",
    "            attention_mask = model_kwargs[\"attention_mask\"]\n",
    "            model_kwargs[\"attention_mask\"] = torch.cat(\n",
    "                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1\n",
    "            )\n",
    "\n",
    "        # 更新位置 ids\n",
    "        if \"position_ids\" in model_kwargs:\n",
    "            position_ids = model_kwargs[\"position_ids\"]\n",
    "            new_position_id = position_ids[..., -1:].clone()\n",
    "            new_position_id += 1\n",
    "            model_kwargs[\"position_ids\"] = torch.cat(\n",
    "                [position_ids, new_position_id], dim=-1\n",
    "            )\n",
    "\n",
    "        model_kwargs[\"is_first_forward\"] = False\n",
    "        return model_kwargs\n",
    "\n",
    "    def prepare_inputs_for_generation(\n",
    "            self,\n",
    "            input_ids: torch.LongTensor,\n",
    "            past_key_values: Optional[torch.Tensor] = None,\n",
    "            attention_mask: Optional[torch.Tensor] = None,\n",
    "            position_ids: Optional[torch.Tensor] = None,\n",
    "            use_cache: Optional[bool] = None,\n",
    "            is_first_forward: bool = True,\n",
    "            **kwargs\n",
    "    ) -> dict:\n",
    "        # 如果 past_key_values 不为空，只取 input_ids 的最后一个 token\n",
    "        if position_ids is None:\n",
    "            position_ids = self.get_position_ids(input_ids, device=input_ids.device)\n",
    "        if not is_first_forward:\n",
    "            if past_key_values is not None:\n",
    "                position_ids = position_ids[..., -1:]\n",
    "                input_ids = input_ids[:, -1:]\n",
    "        return {\n",
    "            \"input_ids\": input_ids,\n",
    "            \"past_key_values\": past_key_values,\n",
    "            \"position_ids\": position_ids,\n",
    "            \"attention_mask\": attention_mask,\n",
    "            \"return_last_logit\": True,\n",
    "            \"use_cache\": use_cache\n",
    "        }\n",
    "\n",
    "    def forward(\n",
    "            self,\n",
    "            input_ids: Optional[torch.Tensor] = None,\n",
    "            position_ids: Optional[torch.Tensor] = None,\n",
    "            attention_mask: Optional[torch.Tensor] = None,\n",
    "            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,\n",
    "            inputs_embeds: Optional[torch.Tensor] = None,\n",
    "            labels: Optional[torch.Tensor] = None,\n",
    "            use_cache: Optional[bool] = None,\n",
    "            output_attentions: Optional[bool] = None,\n",
    "            output_hidden_states: Optional[bool] = None,\n",
    "            return_dict: Optional[bool] = None,\n",
    "            return_last_logit: Optional[bool] = False,\n",
    "    ):\n",
    "        use_cache = use_cache if use_cache is not None else self.config.use_cache\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        transformer_outputs = self.transformer(\n",
    "            input_ids=input_ids,\n",
    "            position_ids=position_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            past_key_values=past_key_values,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            use_cache=use_cache,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )  # 使用 ChatGLMModel 类\n",
    "\n",
    "        hidden_states = transformer_outputs[0]\n",
    "        if return_last_logit:\n",
    "            hidden_states = hidden_states[:, -1:]\n",
    "        lm_logits = self.transformer.output_layer(hidden_states)\n",
    "\n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            lm_logits = lm_logits.to(torch.float32)\n",
    "\n",
    "            # Shift so that tokens < n predict n\n",
    "            shift_logits = lm_logits[..., :-1, :].contiguous()\n",
    "            shift_labels = labels[..., 1:].contiguous()\n",
    "            # Flatten the tokens\n",
    "            loss_fct = CrossEntropyLoss(ignore_index=-100)\n",
    "            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n",
    "\n",
    "            lm_logits = lm_logits.to(hidden_states.dtype)\n",
    "            loss = loss.to(hidden_states.dtype)\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (lm_logits,) + transformer_outputs[1:]\n",
    "            return ((loss,) + output) if loss is not None else output\n",
    "\n",
    "        return CausalLMOutputWithPast(\n",
    "            loss=loss,\n",
    "            logits=lm_logits,\n",
    "            past_key_values=transformer_outputs.past_key_values,\n",
    "            hidden_states=transformer_outputs.hidden_states,\n",
    "            attentions=transformer_outputs.attentions,\n",
    "        )\n",
    "\n",
    "    @staticmethod\n",
    "    def _reorder_cache(\n",
    "            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor\n",
    "    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:\n",
    "        \"\"\"\n",
    "        重新排序 `past_key_values` 缓存以匹配每个生成步骤中的 `beam_idx`。\n",
    "        \"\"\"\n",
    "        return tuple(\n",
    "            (\n",
    "                layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),\n",
    "                layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),\n",
    "            )\n",
    "            for layer_past in past\n",
    "        )\n",
    "\n",
    "    def process_response(self, output, history):\n",
    "        content = \"\"\n",
    "        history = deepcopy(history)\n",
    "        for response in output.split(\"\"):\n",
    "            if \"\\n\" in response:\n",
    "                metadata, content = response.split(\"\\n\", maxsplit=1)\n",
    "            else:\n",
    "                metadata, content = \"\", response\n",
    "            if not metadata.strip():\n",
    "                content = content.strip()\n",
    "                history.append({\"role\": \"assistant\", \"metadata\": metadata, \"content\": content})\n",
    "                content = content.replace(\"[[训练时间]]\", \"2023年\")\n",
    "            else:\n",
    "                history.append({\"role\": \"assistant\", \"metadata\": metadata, \"content\": content})\n",
    "                if history[0][\"role\"] == \"system\" and \"tools\" in history[0]:\n",
    "                    parameters = json.loads(content)\n",
    "                    content = {\"name\": metadata.strip(), \"parameters\": parameters}\n",
    "                else:\n",
    "                    content = {\"name\": metadata.strip(), \"content\": content}\n",
    "        return content, history\n",
    "\n",
    "    @torch.inference_mode()\n",
    "    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = \"user\",\n",
    "             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,\n",
    "             **kwargs):\n",
    "        if history is None:\n",
    "            history = []\n",
    "        if logits_processor is None:\n",
    "            logits_processor = LogitsProcessorList()\n",
    "        logits_processor.append(InvalidScoreLogitsProcessor())\n",
    "        gen_kwargs = {\"max_length\": max_length, \"num_beams\": num_beams, \"do_sample\": do_sample, \"top_p\": top_p,\n",
    "                      \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n",
    "        history.append({\"role\": role, \"content\": query})\n",
    "        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True,\n",
    "                                               return_tensors=\"pt\", return_dict=True)\n",
    "        inputs = inputs.to(self.device)\n",
    "        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"\"),\n",
    "                        tokenizer.convert_tokens_to_ids(\"\")]\n",
    "        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)\n",
    "        outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "        response = tokenizer.decode(outputs)\n",
    "        response, history = self.process_response(response, history)\n",
    "        return response, history\n",
    "\n",
    "    @torch.inference_mode()\n",
    "    def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = \"user\",\n",
    "                    past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,\n",
    "                    logits_processor=None, return_past_key_values=False, **kwargs):\n",
    "        if history is None:\n",
    "            history = []\n",
    "        if logits_processor is None:\n",
    "            logits_processor = LogitsProcessorList()\n",
    "        logits_processor.append(InvalidScoreLogitsProcessor())\n",
    "        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"\"),\n",
    "                        tokenizer.convert_tokens_to_ids(\"\")]\n",
    "        gen_kwargs = {\"max_length\": max_length, \"do_sample\": do_sample, \"top_p\": top_p,\n",
    "                      \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n",
    "        if past_key_values is None:\n",
    "            inputs = tokenizer.apply_chat_template(history + [{\"role\": role, \"content\": query}],\n",
    "                                                   add_generation_prompt=True, tokenize=True, return_tensors=\"pt\",\n",
    "                                                   return_dict=True)\n",
    "        else:\n",
    "            inputs = tokenizer.apply_chat_template([{\"role\": role, \"content\": query}], add_special_tokens=False,\n",
    "                                                   add_generation_prompt=True, tokenize=True, return_tensors=\"pt\",\n",
    "                                                   return_dict=True)\n",
    "        inputs = inputs.to(self.device)\n",
    "        if past_key_values is not None:\n",
    "            past_length = past_key_values[0][0].shape[2]\n",
    "            inputs.position_ids += past_length\n",
    "\n",
    "\n",
    "            attention_mask = inputs.attention_mask\n",
    "            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)\n",
    "            inputs['attention_mask'] = attention_mask\n",
    "        history.append({\"role\": role, \"content\": query})\n",
    "        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,\n",
    "                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,\n",
    "                                            **gen_kwargs):\n",
    "            if return_past_key_values:\n",
    "                outputs, past_key_values = outputs\n",
    "            outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "            response = tokenizer.decode(outputs)\n",
    "            if response and response[-1] != \"�\":\n",
    "                response, new_history = self.process_response(response, history)\n",
    "                if return_past_key_values:\n",
    "                    yield response, new_history, past_key_values\n",
    "                else:\n",
    "                    yield response, new_history\n",
    "\n",
    "    @torch.inference_mode()\n",
    "    def stream_generate(\n",
    "            self,\n",
    "            input_ids,\n",
    "            generation_config: Optional[GenerationConfig] = None,\n",
    "            logits_processor: Optional[LogitsProcessorList] = None,\n",
    "            stopping_criteria: Optional[StoppingCriteriaList] = None,\n",
    "            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,\n",
    "            return_past_key_values=False,\n",
    "            **kwargs,\n",
    "    ):\n",
    "        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]\n",
    "\n",
    "        if generation_config is None:\n",
    "            generation_config = self.generation_config\n",
    "        generation_config = copy.deepcopy(generation_config)\n",
    "        model_kwargs = generation_config.update(**kwargs)\n",
    "        model_kwargs[\"use_cache\"] = generation_config.use_cache\n",
    "        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id\n",
    "\n",
    "        if isinstance(eos_token_id, int):\n",
    "            eos_token_id = [eos_token_id]\n",
    "        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None\n",
    "\n",
    "        has_default_max_length = kwargs.get(\"max_length\") is None and generation_config.max_length is not None\n",
    "        if has_default_max_length and generation_config.max_new_tokens is None:\n",
    "            warnings.warn(\n",
    "                f\"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. \"\n",
    "                \"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we\"\n",
    "                \" recommend using `max_new_tokens` to control the maximum length of the generation.\",\n",
    "                UserWarning,\n",
    "            )\n",
    "        elif generation_config.max_new_tokens is not None:\n",
    "            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length\n",
    "            if not has_default_max_length:\n",
    "                logger.warn(\n",
    "                    f\"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=\"\n",
    "                    f\"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. \"\n",
    "                    \"Please refer to the documentation for more information. \"\n",
    "                    \"(https://hf-mirror.com/docs/transformers/main/en/main_classes/text_generation)\",\n",
    "                    UserWarning,\n",
    "                )\n",
    "\n",
    "        if input_ids_seq_length >= generation_config.max_length:\n",
    "            input_ids_string = \"decoder_input_ids\" if self.config.is_encoder_decoder else \"input_ids\"\n",
    "            logger.warning(\n",
    "                f\"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to\"\n",
    "                f\" {generation_config.max_length}. This can lead to unexpected behavior. You should consider\"\n",
    "                \" increasing `max_new_tokens`.\"\n",
    "            )\n",
    "\n",
    "        # 2. Set generation parameters if not already defined\n",
    "        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()\n",
    "        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()\n",
    "\n",
    "        logits_processor = self._get_logits_processor(\n",
    "            generation_config=generation_config,\n",
    "            input_ids_seq_length=input_ids_seq_length,\n",
    "            encoder_input_ids=input_ids,\n",
    "            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n",
    "            logits_processor=logits_processor,\n",
    "        )\n",
    "\n",
    "        stopping_criteria = self._get_stopping_criteria(\n",
    "            generation_config=generation_config, stopping_criteria=stopping_criteria\n",
    "        )\n",
    "        logits_warper = self._get_logits_warper(generation_config)\n",
    "\n",
    "        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)\n",
    "        scores = None\n",
    "        while True:\n",
    "            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)\n",
    "            # 前向传递获取下一个 token\n",
    "            outputs = self(\n",
    "                **model_inputs,\n",
    "                return_dict=True,\n",
    "                output_attentions=False,\n",
    "                output_hidden_states=False,\n",
    "            )\n",
    "\n",
    "            next_token_logits = outputs.logits[:, -1, :]\n",
    "\n",
    "            # 预处理分布\n",
    "            next_token_scores = logits_processor(input_ids, next_token_logits)\n",
    "            next_token_scores = logits_warper(input_ids, next_token_scores)\n",
    "\n",
    "            # 采样\n",
    "            probs = nn.functional.softmax(next_token_scores, dim=-1)\n",
    "            if generation_config.do_sample:\n",
    "                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)\n",
    "            else:\n",
    "                next_tokens = torch.argmax(probs, dim=-1)\n",
    "            # 更新生成的 ids、模型输入和下一个步骤的长度\n",
    "            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
    "            model_kwargs = self._update_model_kwargs_for_generation(\n",
    "                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder\n",
    "            )\n",
    "            unfinished_sequences = unfinished_sequences.mul(\n",
    "                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)\n",
    "            )\n",
    "            if return_past_key_values:\n",
    "                yield input_ids, outputs.past_key_values\n",
    "            else:\n",
    "                yield input_ids\n",
    "            # 当每个句子完成时或超出最大长度时停止\n",
    "            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):\n",
    "                break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20a7db6f-8889-4e73-b882-edf7e4c72ca7",
   "metadata": {},
   "source": [
    "### ChatGLMForSequenceClassification 类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "832d1a09-3f8c-4035-ad73-ec3433a71f50",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):\n",
    "    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):\n",
    "        super().__init__(config)\n",
    "\n",
    "        self.num_labels = config.num_labels  # 标签数量\n",
    "        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)  # 使用 ChatGLMModel 类\n",
    "\n",
    "        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)\n",
    "        if config.classifier_dropout is not None:\n",
    "            self.dropout = nn.Dropout(config.classifier_dropout)\n",
    "        else:\n",
    "            self.dropout = None\n",
    "        self.config = config\n",
    "\n",
    "    def forward(\n",
    "            self,\n",
    "            input_ids: Optional[torch.LongTensor] = None,\n",
    "            position_ids: Optional[torch.LongTensor] = None,\n",
    "            attention_mask: Optional[torch.Tensor] = None,\n",
    "            full_attention_mask: Optional[torch.Tensor] = None,\n",
    "            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n",
    "            inputs_embeds: Optional[torch.LongTensor] = None,\n",
    "            labels: Optional[torch.LongTensor] = None,\n",
    "            use_cache: Optional[bool] = None,\n",
    "            output_hidden_states: Optional[bool] = None,\n",
    "            return_dict: Optional[bool] = None,\n",
    "    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        transformer_outputs = self.transformer(\n",
    "            input_ids=input_ids,\n",
    "            position_ids=position_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            full_attention_mask=full_attention_mask,\n",
    "            past_key_values=past_key_values,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            use_cache=use_cache,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )  # 使用 ChatGLMModel 类\n",
    "\n",
    "        hidden_states = transformer_outputs[0]\n",
    "        pooled_hidden_states = hidden_states[:, -1]\n",
    "        if self.dropout is not None:\n",
    "            pooled_hidden_states = self.dropout(pooled_hidden_states)\n",
    "        logits = self.classifier_head(pooled_hidden_states)\n",
    "\n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            if self.config.problem_type is None:\n",
    "                if self.num_labels == 1:\n",
    "                    self.config.problem_type = \"regression\"\n",
    "                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n",
    "                    self.config.problem_type = \"single_label_classification\"\n",
    "                else:\n",
    "                    self.config.problem_type = \"multi_label_classification\"\n",
    "\n",
    "            if self.config.problem_type == \"regression\":\n",
    "                loss_fct = MSELoss()\n",
    "                if self.num_labels == 1:\n",
    "                    loss = loss_fct(logits.squeeze().float(), labels.squeeze())\n",
    "                else:\n",
    "                    loss = loss_fct(logits.float(), labels)\n",
    "            elif self.config.problem_type == \"single_label_classification\":\n",
    "                loss_fct = CrossEntropyLoss()\n",
    "                loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))\n",
    "            elif self.config.problem_type == \"multi_label_classification\":\n",
    "                loss_fct = BCEWithLogitsLoss()\n",
    "                loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (logits,) + transformer_outputs[1:]\n",
    "            return ((loss,) + output) if loss is not None else output\n",
    "\n",
    "        return SequenceClassifierOutputWithPast(\n",
    "            loss\n",
    "\n",
    "=loss,\n",
    "            logits=logits,\n",
    "            past_key_values=transformer_outputs.past_key_values,\n",
    "            hidden_states=transformer_outputs.hidden_states,\n",
    "            attentions=transformer_outputs.attentions,\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "415020b6-a65c-4738-aaac-c25b00c84996",
   "metadata": {},
   "source": [
    "### 使用的先前构建的模块\n",
    "\n",
    "在这段代码中，多个模块和方法是基于之前构建的类和函数的：\n",
    "\n",
    "1. **ChatGLMModel**：\n",
    "   - **用于 ChatGLMForConditionalGeneration 和 ChatGLMForSequenceClassification 中**，作为 Transformer 模型的核心部分。\n",
    "\n",
    "2. **ChatGLMPreTrainedModel**：\n",
    "   - **作为 ChatGLMForConditionalGeneration 和 ChatGLMForSequenceClassification 的基类**，提供权重初始化和加载预训练模型的接口。\n",
    "\n",
    "3. **RotaryEmbedding**：\n",
    "   - **在 ChatGLMModel 中用于位置编码**。\n",
    "\n",
    "4. **CoreAttention 和 SelfAttention**：\n",
    "   - **在 GLMTransformer 中使用**，实现了注意力机制的核心部分。\n",
    "\n",
    "通过详细注释和说明，可以更好地理解代码的构建和实现原理。这些模块共同构成了 ChatGLM 模型的整体架构，实现了条件生成和序列分类的功能。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d43e2c5d-40de-4c9d-a8b7-e49099306fa7",
   "metadata": {},
   "source": [
    "### `past_key_values` 变量的含义\n",
    "\n",
    "在 Transformer 模型中，特别是用于生成任务的模型，如 GPT 类模型中，`past_key_values` 是一个非常重要的变量。它用于缓存模型在前一个时间步计算得到的键（key）和值（value）向量。这些缓存的数据可以在后续的时间步中重复使用，从而提高计算效率，尤其是在长序列生成任务中。\n",
    "\n",
    "### `past_key_values` 的作用\n",
    "\n",
    "1. **缓存先前计算结果**：\n",
    "   在生成文本的过程中，每一步生成一个新的词，这时需要将当前时间步的查询向量（query）与所有先前时间步的键和值向量进行计算。如果每次都重新计算所有的键和值，将会非常低效。`past_key_values` 缓存了这些先前时间步的结果，避免了重复计算。\n",
    "\n",
    "2. **加速生成过程**：\n",
    "   在长序列生成中，通过缓存先前时间步的键和值向量，只需要对当前时间步进行计算并与缓存结果结合，大大加速了生成过程。\n",
    "\n",
    "### 具体实现中的 `past_key_values`\n",
    "\n",
    "#### 在 `ChatGLMForConditionalGeneration` 类中的使用\n",
    "\n",
    "```python\n",
    "def prepare_inputs_for_generation(\n",
    "        self,\n",
    "        input_ids: torch.LongTensor,\n",
    "        past_key_values: Optional[torch.Tensor] = None,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.Tensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        is_first_forward: bool = True,\n",
    "        **kwargs\n",
    ") -> dict:\n",
    "    # 如果 past_key_values 不为空，只取 input_ids 的最后一个 token\n",
    "    if position_ids is None:\n",
    "        position_ids = self.get_position_ids(input_ids, device=input_ids.device)\n",
    "    if not is_first_forward:\n",
    "        if past_key_values is not None:\n",
    "            position_ids = position_ids[..., -1:]\n",
    "            input_ids = input_ids[:, -1:]\n",
    "    return {\n",
    "        \"input_ids\": input_ids,\n",
    "        \"past_key_values\": past_key_values,\n",
    "        \"position_ids\": position_ids,\n",
    "        \"attention_mask\": attention_mask,\n",
    "        \"return_last_logit\": True,\n",
    "        \"use_cache\": use_cache\n",
    "    }\n",
    "```\n",
    "\n",
    "在 `prepare_inputs_for_generation` 方法中，如果 `past_key_values` 不为空，只会取 `input_ids` 的最后一个 token。这样做的目的是为了在生成新 token 时，只计算当前时间步的数据，而不需要重新计算整个序列。\n",
    "\n",
    "#### 在 `ChatGLMForConditionalGeneration` 类的 `forward` 方法中\n",
    "\n",
    "```python\n",
    "def forward(\n",
    "        self,\n",
    "        input_ids: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.Tensor] = None,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,\n",
    "        inputs_embeds: Optional[torch.Tensor] = None,\n",
    "        labels: Optional[torch.Tensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        output_attentions: Optional[bool] = None,\n",
    "        output_hidden_states: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = None,\n",
    "        return_last_logit: Optional[bool] = False,\n",
    "):\n",
    "    use_cache = use_cache if use_cache is not None else self.config.use_cache\n",
    "    return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "    transformer_outputs = self.transformer(\n",
    "        input_ids=input_ids,\n",
    "        position_ids=position_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        past_key_values=past_key_values,\n",
    "        inputs_embeds=inputs_embeds,\n",
    "        use_cache=use_cache,\n",
    "        output_hidden_states=output_hidden_states,\n",
    "        return_dict=return_dict,\n",
    "    )  # 使用 ChatGLMModel 类\n",
    "\n",
    "    hidden_states = transformer_outputs[0]\n",
    "    if return_last_logit:\n",
    "        hidden_states = hidden_states[:, -1:]\n",
    "    lm_logits = self.transformer.output_layer(hidden_states)\n",
    "\n",
    "    loss = None\n",
    "    if labels is not None:\n",
    "        lm_logits = lm_logits.to(torch.float32)\n",
    "\n",
    "        # Shift so that tokens < n predict n\n",
    "        shift_logits = lm_logits[..., :-1, :].contiguous()\n",
    "        shift_labels = labels[..., 1:].contiguous()\n",
    "        # Flatten the tokens\n",
    "        loss_fct = CrossEntropyLoss(ignore_index=-100)\n",
    "        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n",
    "\n",
    "        lm_logits = lm_logits.to(hidden_states.dtype)\n",
    "        loss = loss.to(hidden_states.dtype)\n",
    "\n",
    "    if not return_dict:\n",
    "        output = (lm_logits,) + transformer_outputs[1:]\n",
    "        return ((loss,) + output) if loss is not None else output\n",
    "\n",
    "    return CausalLMOutputWithPast(\n",
    "        loss=loss,\n",
    "        logits=lm_logits,\n",
    "        past_key_values=transformer_outputs.past_key_values,\n",
    "        hidden_states=transformer_outputs.hidden_states,\n",
    "        attentions=transformer_outputs.attentions,\n",
    "    )\n",
    "```\n",
    "\n",
    "在 `forward` 方法中，`past_key_values` 作为参数传递给 `transformer` 模型。`transformer` 模型内部会使用这些缓存的键和值向量来加速计算。\n",
    "`past_key_values` 是一种用于加速 Transformer 模型在生成任务中的缓存机制。它保存了前一个时间步计算得到的键和值向量，避免了在每个时间步中重复计算这些向量，从而提高了生成过程的效率。通过使用 `past_key_values`，模型可以更快地生成长序列数据，这在实际应用中是非常重要的。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0e361aa-a1b5-4eae-9178-306b888713f4",
   "metadata": {},
   "source": [
    "### 函数分析及其区别和联系\n",
    "\n",
    "在 `ChatGLMForConditionalGeneration` 类中，有三个与生成有关的重要函数：`_update_model_kwargs_for_generation`、`prepare_inputs_for_generation` 和 `forward`。它们的作用、区别和联系如下：\n",
    "\n",
    "#### 1. `_update_model_kwargs_for_generation` 函数\n",
    "\n",
    "```python\n",
    "def _update_model_kwargs_for_generation(\n",
    "        self,\n",
    "        outputs: ModelOutput,\n",
    "        model_kwargs: Dict[str, Any],\n",
    "        is_encoder_decoder: bool = False,\n",
    "        standardize_cache_format: bool = False,\n",
    ") -> Dict[str, Any]:\n",
    "    # 更新 past_key_values\n",
    "    model_kwargs[\"past_key_values\"] = self._extract_past_from_model_output(\n",
    "        outputs, standardize_cache_format=standardize_cache_format\n",
    "    )\n",
    "\n",
    "    # 更新注意力掩码\n",
    "    if \"attention_mask\" in model_kwargs:\n",
    "        attention_mask = model_kwargs[\"attention_mask\"]\n",
    "        model_kwargs[\"attention_mask\"] = torch.cat(\n",
    "            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1\n",
    "        )\n",
    "\n",
    "    # 更新位置 ids\n",
    "    if \"position_ids\" in model_kwargs:\n",
    "        position_ids = model_kwargs[\"position_ids\"]\n",
    "        new_position_id = position_ids[..., -1:].clone()\n",
    "        new_position_id += 1\n",
    "        model_kwargs[\"position_ids\"] = torch.cat(\n",
    "            [position_ids, new_position_id], dim=-1\n",
    "        )\n",
    "\n",
    "    model_kwargs[\"is_first_forward\"] = False\n",
    "    return model_kwargs\n",
    "```\n",
    "\n",
    "- **作用**：更新生成过程中所需的模型参数。具体包括：\n",
    "  - 更新 `past_key_values` 以缓存先前计算的键和值向量。\n",
    "  - 更新 `attention_mask` 以包括新的生成的 token。\n",
    "  - 更新 `position_ids` 以增加新的位置 ID。\n",
    "- **区别**：该函数不直接进行前向传播，而是更新模型参数，为下一步的生成做准备。\n",
    "- **联系**：该函数在每一步生成新 token 后调用，用于更新模型参数，为下一步的生成做准备。\n",
    "\n",
    "#### 2. `prepare_inputs_for_generation` 函数\n",
    "\n",
    "```python\n",
    "def prepare_inputs_for_generation(\n",
    "        self,\n",
    "        input_ids: torch.LongTensor,\n",
    "        past_key_values: Optional[torch.Tensor] = None,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.Tensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        is_first_forward: bool = True,\n",
    "        **kwargs\n",
    ") -> dict:\n",
    "    # 如果 past_key_values 不为空，只取 input_ids 的最后一个 token\n",
    "    if position_ids is None:\n",
    "        position_ids = self.get_position_ids(input_ids, device=input_ids.device)\n",
    "    if not is_first_forward:\n",
    "        if past_key_values is not None:\n",
    "            position_ids = position_ids[..., -1:]\n",
    "            input_ids = input_ids[:, -1:]\n",
    "    return {\n",
    "        \"input_ids\": input_ids,\n",
    "        \"past_key_values\": past_key_values,\n",
    "        \"position_ids\": position_ids,\n",
    "        \"attention_mask\": attention_mask,\n",
    "        \"return_last_logit\": True,\n",
    "        \"use_cache\": use_cache\n",
    "    }\n",
    "```\n",
    "\n",
    "- **作用**：准备生成过程所需的输入。具体包括：\n",
    "  - 获取或更新 `position_ids`。\n",
    "  - 如果不是第一次前向传播，且存在 `past_key_values`，则只取 `input_ids` 和 `position_ids` 的最后一个 token。\n",
    "- **区别**：该函数主要用于处理输入数据，确保输入数据的形状和内容适合当前生成步骤。\n",
    "- **联系**：在每一步生成过程中，会调用该函数准备输入数据，尤其是处理 `past_key_values` 以提高生成效率。\n",
    "\n",
    "#### 3. `forward` 函数\n",
    "\n",
    "```python\n",
    "def forward(\n",
    "        self,\n",
    "        input_ids: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.Tensor] = None,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,\n",
    "        inputs_embeds: Optional[torch.Tensor] = None,\n",
    "        labels: Optional[torch.Tensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        output_attentions: Optional[bool] = None,\n",
    "        output_hidden_states: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = None,\n",
    "        return_last_logit: Optional[bool] = False,\n",
    "):\n",
    "    use_cache = use_cache if use_cache is not None else self.config.use_cache\n",
    "    return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "    transformer_outputs = self.transformer(\n",
    "        input_ids=input_ids,\n",
    "        position_ids=position_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        past_key_values=past_key_values,\n",
    "        inputs_embeds=inputs_embeds,\n",
    "        use_cache=use_cache,\n",
    "        output_hidden_states=output_hidden_states,\n",
    "        return_dict=return_dict,\n",
    "    )  # 使用 ChatGLMModel 类\n",
    "\n",
    "    hidden_states = transformer_outputs[0]\n",
    "    if return_last_logit:\n",
    "        hidden_states = hidden_states[:, -1:]\n",
    "    lm_logits = self.transformer.output_layer(hidden_states)\n",
    "\n",
    "    loss = None\n",
    "    if labels is not None:\n",
    "        lm_logits = lm_logits.to(torch.float32)\n",
    "\n",
    "        # Shift so that tokens < n predict n\n",
    "        shift_logits = lm_logits[..., :-1, :].contiguous()\n",
    "        shift_labels = labels[..., 1:].contiguous()\n",
    "        # Flatten the tokens\n",
    "        loss_fct = CrossEntropyLoss(ignore_index=-100)\n",
    "        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n",
    "\n",
    "        lm_logits = lm_logits.to(hidden_states.dtype)\n",
    "        loss = loss.to(hidden_states.dtype)\n",
    "\n",
    "    if not return_dict:\n",
    "        output = (lm_logits,) + transformer_outputs[1:]\n",
    "        return ((loss,) + output) if loss is not None else output\n",
    "\n",
    "    return CausalLMOutputWithPast(\n",
    "        loss=loss,\n",
    "        logits=lm_logits,\n",
    "        past_key_values=transformer_outputs.past_key_values,\n",
    "        hidden_states=transformer_outputs.hidden_states,\n",
    "        attentions=transformer_outputs.attentions,\n",
    "    )\n",
    "```\n",
    "\n",
    "- **作用**：执行前向传播，生成模型的输出。具体包括：\n",
    "  - 将输入数据传递给 `transformer`（`ChatGLMModel`），进行前向计算。\n",
    "  - 计算语言模型的 logits 和（如果有标签）计算损失。\n",
    "  - 返回模型输出，包括 logits、`past_key_values`、隐藏状态和注意力权重。\n",
    "- **区别**：这是模型的核心前向传播逻辑，直接处理输入数据并生成输出。\n",
    "- **联系**：`forward` 函数使用了 `ChatGLMModel` 类来进行实际的前向传播，并调用了之前的 `prepare_inputs_for_generation` 来准备输入。\n",
    "\n",
    "### 联系和流程\n",
    "\n",
    "1. **准备输入数据**：\n",
    "   - `prepare_inputs_for_generation` 函数用于处理输入数据，尤其是处理 `past_key_values` 以便只传递必要的最后一个 token。\n",
    "   \n",
    "2. **执行前向传播**：\n",
    "   - `forward` 函数使用准备好的输入数据进行前向传播，生成输出。\n",
    "\n",
    "3. **更新模型参数**：\n",
    "   - `_update_model_kwargs_for_generation` 函数在每一步生成之后，更新模型的关键参数（如 `past_key_values`、`attention_mask` 和 `position_ids`），确保在下一步生成中使用最新的数据。\n",
    "\n",
    "通过这些函数的紧密配合，可以高效地实现生成任务中的前向传播和缓存管理，从而提高模型的生成效率和效果。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bfbd264-47bc-43bf-853a-4470462e31cd",
   "metadata": {},
   "source": [
    "这三个函数是 `ChatGLMForConditionalGeneration` 类中的核心函数，分别用于处理不同的生成任务需求。它们之间有一定的联系，同时也有各自的用途和特点。以下是对它们的详细解释及其区别和联系：\n",
    "\n",
    "### 1. `chat` 函数\n",
    "\n",
    "```python\n",
    "@torch.inference_mode()\n",
    "def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = \"user\",\n",
    "         max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,\n",
    "         **kwargs):\n",
    "    if history is None:\n",
    "        history = []\n",
    "    if logits_processor is None:\n",
    "        logits_processor = LogitsProcessorList()\n",
    "    logits_processor.append(InvalidScoreLogitsProcessor())\n",
    "    gen_kwargs = {\"max_length\": max_length, \"num_beams\": num_beams, \"do_sample\": do_sample, \"top_p\": top_p,\n",
    "                  \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n",
    "    history.append({\"role\": role, \"content\": query})\n",
    "    inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True,\n",
    "                                           return_tensors=\"pt\", return_dict=True)\n",
    "    inputs = inputs.to(self.device)\n",
    "    eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"\"), tokenizer.convert_tokens_to_ids(\"\")]\n",
    "    outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)\n",
    "    outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "    response = tokenizer.decode(outputs)\n",
    "    response, history = self.process_response(response, history)\n",
    "    return response, history\n",
    "```\n",
    "\n",
    "- **作用**：执行一次完整的聊天会话。将用户的查询和历史记录编码成模型输入，生成响应并更新历史记录。\n",
    "- **区别**：这是一个高层次的接口，适用于一次性生成完整响应。适合用于需要立即获得完整回答的应用场景。\n",
    "- **联系**：它依赖于 `generate` 函数来实际生成响应，并调用 `process_response` 函数来处理生成的输出。\n",
    "\n",
    "### 2. `stream_chat` 函数\n",
    "\n",
    "```python\n",
    "@torch.inference_mode()\n",
    "def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = \"user\",\n",
    "                past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,\n",
    "                logits_processor=None, return_past_key_values=False, **kwargs):\n",
    "    if history is None:\n",
    "        history = []\n",
    "    if logits_processor is None:\n",
    "        logits_processor = LogitsProcessorList()\n",
    "    logits_processor.append(InvalidScoreLogitsProcessor())\n",
    "    eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"\"), tokenizer.convert_tokens_to_ids(\"\")]\n",
    "    gen_kwargs = {\"max_length\": max_length, \"do_sample\": do_sample, \"top_p\": top_p,\n",
    "                  \"temperature\": temperature, \"logits_processor\": logits_processor, **kwargs}\n",
    "    if past_key_values is None:\n",
    "        inputs = tokenizer.apply_chat_template(history + [{\"role\": role, \"content\": query}],\n",
    "                                               add_generation_prompt=True, tokenize=True, return_tensors=\"pt\",\n",
    "                                               return_dict=True)\n",
    "    else:\n",
    "        inputs = tokenizer.apply_chat_template([{\"role\": role, \"content\": query}], add_special_tokens=False,\n",
    "                                               add_generation_prompt=True, tokenize=True, return_tensors=\"pt\",\n",
    "                                               return_dict=True)\n",
    "    inputs = inputs.to(self.device)\n",
    "    if past_key_values is not None:\n",
    "        past_length = past_key_values[0][0].shape[2]\n",
    "        inputs.position_ids += past_length\n",
    "        attention_mask = inputs.attention_mask\n",
    "        attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)\n",
    "        inputs['attention_mask'] = attention_mask\n",
    "    history.append({\"role\": role, \"content\": query})\n",
    "    for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,\n",
    "                                        eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,\n",
    "                                        **gen_kwargs):\n",
    "        if return_past_key_values:\n",
    "            outputs, past_key_values = outputs\n",
    "        outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "        response = tokenizer.decode(outputs)\n",
    "        if response and response[-1] != \"�\":\n",
    "            response, new_history = self.process_response(response, history)\n",
    "            if return_past_key_values:\n",
    "                yield response, new_history, past_key_values\n",
    "            else:\n",
    "                yield response, new_history\n",
    "```\n",
    "\n",
    "- **作用**：实现流式聊天会话。与 `chat` 函数类似，但它通过生成器逐步返回响应，适用于流式生成应用场景。\n",
    "- **区别**：支持逐步生成响应，使得可以在生成过程中动态处理和显示部分响应。\n",
    "- **联系**：依赖于 `stream_generate` 函数来逐步生成响应，并在每次生成新的响应片段后调用 `process_response` 函数来处理和更新历史记录。\n",
    "\n",
    "### 3. `stream_generate` 函数\n",
    "\n",
    "```python\n",
    "@torch.inference_mode()\n",
    "def stream_generate(\n",
    "        self,\n",
    "        input_ids,\n",
    "        generation_config: Optional[GenerationConfig] = None,\n",
    "        logits_processor: Optional[LogitsProcessorList] = None,\n",
    "        stopping_criteria: Optional[StoppingCriteriaList] = None,\n",
    "        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,\n",
    "        return_past_key_values=False,\n",
    "        **kwargs,\n",
    "):\n",
    "    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]\n",
    "\n",
    "    if generation_config is None:\n",
    "        generation_config = self.generation_config\n",
    "    generation_config = copy.deepcopy(generation_config)\n",
    "    model_kwargs = generation_config.update(**kwargs)\n",
    "    model_kwargs[\"use_cache\"] = generation_config.use_cache\n",
    "    bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id\n",
    "\n",
    "    if isinstance(eos_token_id, int):\n",
    "        eos_token_id = [eos_token_id]\n",
    "    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None\n",
    "\n",
    "    has_default_max_length = kwargs.get(\"max_length\") is None and generation_config.max_length is not None\n",
    "    if has_default_max_length and generation_config.max_new_tokens is None:\n",
    "        warnings.warn(\n",
    "            f\"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. \"\n",
    "            \"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we\"\n",
    "            \" recommend using `max_new_tokens` to control the maximum length of the generation.\",\n",
    "            UserWarning,\n",
    "        )\n",
    "    elif generation_config.max_new_tokens is not None:\n",
    "        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length\n",
    "        if not has_default_max_length:\n",
    "            logger.warn(\n",
    "                f\"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=\"\n",
    "                f\"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. \"\n",
    "                \"Please refer to the documentation for more information. \"\n",
    "                \"(https://hf-mirror.com/docs/transformers/main/en/main_classes/text_generation)\",\n",
    "                UserWarning,\n",
    "            )\n",
    "\n",
    "    if input_ids_seq_length >= generation_config.max_length:\n",
    "        input_ids_string = \"decoder_input_ids\" if self.config.is_encoder_decoder else \"input_ids\"\n",
    "        logger.warning(\n",
    "            f\"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to\"\n",
    "            f\" {generation_config.max_length}. This can lead to unexpected behavior. You should consider\"\n",
    "            \" increasing `max_new_tokens`.\"\n",
    "        )\n",
    "\n",
    "    # 2. Set generation parameters if not already defined\n",
    "    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()\n",
    "    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()\n",
    "\n",
    "    logits_processor = self._get_logits_processor(\n",
    "        generation_config=generation_config,\n",
    "        input_ids_seq_length=input_ids_seq_length,\n",
    "        encoder_input_ids=input_ids,\n",
    "        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n",
    "        logits_processor=logits_processor,\n",
    "    )\n",
    "\n",
    "    stopping_criteria = self._get_stopping_criteria(\n",
    "        generation_config=generation_config, stopping_criteria=stopping_criteria\n",
    "    )\n",
    "    logits_warper = self._get_logits_warper(generation_config)\n",
    "\n",
    "    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)\n",
    "    scores = None\n",
    "    while True:\n",
    "        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)\n",
    "        # 前向传递获取下一个 token\n",
    "        outputs = self(\n",
    "            **model_inputs,\n",
    "            return_dict=True,\n",
    "            output_attentions=False,\n",
    "            output_hidden_states=False,\n",
    "        )\n",
    "\n",
    "        next_token_logits = outputs.logits[:, -1, :]\n",
    "\n",
    "        # 预处理分布\n",
    "        next_token_scores = logits_processor(input_ids, next_token_logits)\n",
    "        next_token_scores = logits_warper(input_ids, next_token_scores)\n",
    "\n",
    "        # 采样\n",
    "        probs = nn.functional.softmax(next_token_scores, dim=-1)\n",
    "        if generation_config.do_sample:\n",
    "            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)\n",
    "        else:\n",
    "            next_tokens = torch.argmax(probs, dim=-1)\n",
    "       \n",
    "\n",
    " # 更新生成的 ids、模型输入和下一个步骤的长度\n",
    "        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
    "        model_kwargs = self._update_model_kwargs_for_generation(\n",
    "            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder\n",
    "        )\n",
    "        unfinished_sequences = unfinished_sequences.mul(\n",
    "            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)\n",
    "        )\n",
    "        if return_past_key_values:\n",
    "            yield input_ids, outputs.past_key_values\n",
    "        else:\n",
    "            yield input_ids\n",
    "        # 当每个句子完成时或超出最大长度时停止\n",
    "        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):\n",
    "            break\n",
    "```\n",
    "\n",
    "- **作用**：流式生成模型输出，逐步返回生成的 token 以便实时处理。\n",
    "- **区别**：该函数通过生成器实现流式生成，每生成一个 token 就返回一次结果，适用于需要逐步展示生成结果的应用场景。\n",
    "- **联系**：`stream_chat` 函数依赖 `stream_generate` 来逐步生成响应，并在每次生成新的 token 后更新输入和模型参数。\n",
    "\n",
    "### 区别和联系总结\n",
    "\n",
    "1. **区别**：\n",
    "   - `chat` 函数：用于一次性生成完整响应，适合需要立即获得完整回答的应用场景。\n",
    "   - `stream_chat` 函数：用于流式生成响应，适合逐步展示生成结果的应用场景。\n",
    "   - `stream_generate` 函数：实现流式生成的核心逻辑，通过生成器逐步返回生成的 token。\n",
    "\n",
    "2. **联系**：\n",
    "   - `chat` 和 `stream_chat` 都是高层次接口，用户通过这些接口与模型交互。\n",
    "   - `chat` 函数调用 `generate` 函数实现一次性生成，而 `stream_chat` 函数调用 `stream_generate` 函数实现流式生成。\n",
    "   - `stream_generate` 函数使用了 `prepare_inputs_for_generation` 来处理输入数据，并通过 `_update_model_kwargs_for_generation` 更新模型参数，确保在生成过程中使用最新的数据。\n",
    "\n",
    "通过这三个函数的协调工作，模型能够实现高效且灵活的生成任务，满足不同应用场景的需求。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "62284871-67f4-47a2-941f-46befd2032b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "device = \"cuda\"\n",
    "\n",
    "tokenizer = ChatGLM4Tokenizer.from_pretrained(\"THUDM/glm-4-9b-chat\", trust_remote_code=True)\n",
    "\n",
    "query = \"你好\"\n",
    "\n",
    "inputs = tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": query}],\n",
    "                                       add_generation_prompt=True,\n",
    "                                       tokenize=True,\n",
    "                                       return_tensors=\"pt\",\n",
    "                                       return_dict=True\n",
    "                                       )\n",
    "\n",
    "inputs = inputs.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "0801f0ee-2d71-4eda-b2f5-5e00209cd0fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
      "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41bf96f6983b4492b2d95ae8f5eaa7ae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "你好👋！有什么可以帮助你的吗？\n"
     ]
    }
   ],
   "source": [
    "model = ChatGLMForConditionalGeneration.from_pretrained(\n",
    "    \"THUDM/glm-4-9b-chat\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    trust_remote_code=True,\n",
    "    load_in_4bit=True,\n",
    "    device_map='auto'\n",
    ").eval()#.to(device)\n",
    "\n",
    "gen_kwargs = {\"max_length\": 300, \"do_sample\": True, \"top_k\": 1}\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(**inputs, **gen_kwargs)\n",
    "    outputs = outputs[:, inputs['input_ids'].shape[1]:]\n",
    "    print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42e395e4-2f03-429b-b233-f612a7482ad5",
   "metadata": {},
   "source": [
    "耗时16s"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41c190f6-e495-49cc-8c20-70e4e8eceefd",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 备注：其他测试草稿"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "923c9f19-430f-468d-8387-b73e45badeab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_chat_template(\n",
    "        self,\n",
    "        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], \"Conversation\"],\n",
    "        add_generation_prompt: bool = False,\n",
    "        tokenize: bool = True,\n",
    "        padding: bool = False,\n",
    "        truncation: bool = False,\n",
    "        max_length: Optional[int] = None,\n",
    "        return_tensors: Optional[Union[str, TensorType]] = None,\n",
    "        return_dict: bool = False,\n",
    "        tokenizer_kwargs: Optional[Dict[str, Any]] = None,\n",
    "        add_special_tokens: bool = True,\n",
    "        **kwargs,\n",
    ") -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:\n",
    "\n",
    "    if return_dict and not tokenize:\n",
    "        raise ValueError(\n",
    "            \"`return_dict=True` is incompatible with `tokenize=False`, because there is no dict \"\n",
    "            \"of tokenizer outputs to return.\"\n",
    "        )\n",
    "\n",
    "    def handle_single_conversation(conversation):\n",
    "        input_ids = self.get_prefix_tokens() if add_special_tokens else []\n",
    "        input_message = \"[gMASK]<sop>\" if add_special_tokens else \"\"\n",
    "        for item in conversation:\n",
    "            if item.get(\"tools\"):\n",
    "                tools = item[\"tools\"]\n",
    "                content = \"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。\"\n",
    "                for tool in tools:\n",
    "                    if tool[\"type\"] == \"function\":\n",
    "                        function = tool[\"function\"]\n",
    "                        content += f\"\\n\\n## {function['name']}\\n\\n{json.dumps(function, ensure_ascii=False, indent=4)}\"\n",
    "                        content += \"\\n在调用上述函数时，请使用 Json 格式表示调用的参数。\"\n",
    "                    elif tool[\"type\"] == \"python\":\n",
    "                        content += \"\\n\\n## python\\n\\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。\"\n",
    "                    elif tool[\"type\"] == \"simple_browser\":\n",
    "                        content += \"\\n\\n## simple_browser\\n\\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\\n`open_url(url: str)`：打开指定的 URL。\\n\\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\\n\\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。\"\n",
    "                    elif tool[\"type\"] == \"cogview\":\n",
    "                        content += \"\\n\\n## cogview\\n\\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。\"\n",
    "                    else:\n",
    "                        raise NotImplementedError(f\"Unknown tool type {tool['type']}\")\n",
    "                input = self.build_single_message(\"system\", \"\", content, tokenize=tokenize)\n",
    "                if tokenize:\n",
    "                    input_ids.extend(input)\n",
    "                else:\n",
    "                    input_message += input\n",
    "            if item[\"content\"]:\n",
    "                input = self.build_single_message(\n",
    "                    item[\"role\"],\n",
    "                    item.get(\"metadata\", \"\"),\n",
    "                    item[\"content\"],\n",
    "                    tokenize=tokenize\n",
    "                )\n",
    "                if tokenize:\n",
    "                    input_ids.extend(input)\n",
    "                else:\n",
    "                    input_message += input\n",
    "        if add_generation_prompt:\n",
    "            if tokenize:\n",
    "                input_ids.extend([self.convert_tokens_to_ids(\"<|assistant|>\")])\n",
    "            else:\n",
    "                input_message += \"<|assistant|>\"\n",
    "\n",
    "        return input_ids if tokenize else input_message\n",
    "\n",
    "    # Main logic to handle different conversation formats\n",
    "    if isinstance(conversation, list) and all(isinstance(i, dict) for i in conversation):\n",
    "        result = handle_single_conversation(conversation)\n",
    "    elif isinstance(conversation, list) and all(isinstance(i, list) for i in conversation):\n",
    "        result = [handle_single_conversation(c) for c in conversation]\n",
    "    elif hasattr(conversation, \"messages\"):\n",
    "        result = handle_single_conversation(conversation.messages)\n",
    "    else:\n",
    "        raise ValueError(\"Invalid conversation format\")\n",
    "\n",
    "    if tokenize:\n",
    "        output = self.batch_encode_plus(\n",
    "            [result] if isinstance(result[0], int) else result,\n",
    "            padding=padding,\n",
    "            truncation=truncation,\n",
    "            max_length=max_length,\n",
    "            return_tensors=return_tensors,\n",
    "            is_split_into_words=True,\n",
    "            add_special_tokens=False\n",
    "        )\n",
    "        if return_dict:\n",
    "            return output\n",
    "        else:\n",
    "            return output[\"input_ids\"]\n",
    "    else:\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ece39da8-2b28-4415-b226-0919f2beb6de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_single_conversation(conversation):\n",
    "    input_ids = self.get_prefix_tokens() if add_special_tokens else []\n",
    "    input_message = \"[gMASK]<sop>\" if add_special_tokens else \"\"\n",
    "    for item in conversation:\n",
    "        if item.get(\"tools\"):\n",
    "            tools = item[\"tools\"]\n",
    "            content = \"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。\"\n",
    "            for tool in tools:\n",
    "                if tool[\"type\"] == \"function\":\n",
    "                    function = tool[\"function\"]\n",
    "                    content += f\"\\n\\n## {function['name']}\\n\\n{json.dumps(function, ensure_ascii=False, indent=4)}\"\n",
    "                    content += \"\\n在调用上述函数时，请使用 Json 格式表示调用的参数。\"\n",
    "                elif tool[\"type\"] == \"python\":\n",
    "                    content += \"\\n\\n## python\\n\\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。\"\n",
    "                elif tool[\"type\"] == \"simple_browser\":\n",
    "                    content += \"\\n\\n## simple_browser\\n\\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\\n`open_url(url: str)`：打开指定的 URL。\\n\\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\\n\\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。\"\n",
    "                elif tool[\"type\"] == \"cogview\":\n",
    "                    content += \"\\n\\n## cogview\\n\\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。\"\n",
    "                else:\n",
    "                    raise NotImplementedError(f\"Unknown tool type {tool['type']}\")\n",
    "            input = self.build_single_message(\"system\", \"\", content, tokenize=tokenize)\n",
    "            if tokenize:\n",
    "                input_ids.extend(input)\n",
    "            else:\n",
    "                input_message += input\n",
    "        if item[\"content\"]:\n",
    "            input = self.build_single_message(\n",
    "                item[\"role\"],\n",
    "                item.get(\"metadata\", \"\"),\n",
    "                item[\"content\"],\n",
    "                tokenize=tokenize\n",
    "            )\n",
    "            if tokenize:\n",
    "                input_ids.extend(input)\n",
    "            else:\n",
    "                input_message += input\n",
    "    if add_generation_prompt:\n",
    "        if tokenize:\n",
    "            input_ids.extend([self.convert_tokens_to_ids(\"<|assistant|>\")])\n",
    "        else:\n",
    "            input_message += \"<|assistant|>\"\n",
    "\n",
    "    return input_ids if tokenize else input_message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69cc9551-ad44-4ecd-b7a6-00267fbcc186",
   "metadata": {},
   "outputs": [],
   "source": [
    "convert_tokens_to_ids(\"<|assistant|>\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "ef2589e3-5f81-4c93-bb2d-a240b6041411",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_single_conversation(conversation):\n",
    "    input_ids = self.get_prefix_tokens() if add_special_tokens else []\n",
    "    input_message = \"[gMASK]<sop>\" if add_special_tokens else \"\"\n",
    "    for item in conversation:\n",
    "        if item.get(\"tools\"):\n",
    "            tools = item[\"tools\"]\n",
    "            content = \"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。\"\n",
    "            for tool in tools:\n",
    "                if tool[\"type\"] == \"function\":\n",
    "                    function = tool[\"function\"]\n",
    "                    content += f\"\\n\\n## {function['name']}\\n\\n{json.dumps(function, ensure_ascii=False, indent=4)}\"\n",
    "                    content += \"\\n在调用上述函数时，请使用 Json 格式表示调用的参数。\"\n",
    "                elif tool[\"type\"] == \"python\":\n",
    "                    content += \"\\n\\n## python\\n\\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。\"\n",
    "                elif tool[\"type\"] == \"simple_browser\":\n",
    "                    content += \"\\n\\n## simple_browser\\n\\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\\n`open_url(url: str)`：打开指定的 URL。\\n\\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\\n\\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。\"\n",
    "                elif tool[\"type\"] == \"cogview\":\n",
    "                    content += \"\\n\\n## cogview\\n\\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。\"\n",
    "                else:\n",
    "                    raise NotImplementedError(f\"Unknown tool type {tool['type']}\")\n",
    "            input = self.build_single_message(\"system\", \"\", content, tokenize=tokenize)\n",
    "            if tokenize:\n",
    "                input_ids.extend(input)\n",
    "            else:\n",
    "                input_message += input\n",
    "        if item[\"content\"]:\n",
    "            input = self.build_single_message(\n",
    "                item[\"role\"],\n",
    "                item.get(\"metadata\", \"\"),\n",
    "                item[\"content\"],\n",
    "                tokenize=tokenize\n",
    "            )\n",
    "            if tokenize:\n",
    "                input_ids.extend(input)\n",
    "            else:\n",
    "                input_message += input\n",
    "    if add_generation_prompt:\n",
    "        if tokenize:\n",
    "            input_ids.extend([self.convert_tokens_to_ids(\"<|assistant|>\")])\n",
    "        else:\n",
    "            input_message += \"<|assistant|>\"\n",
    "        # if tokenize:\n",
    "        #     input_ids.extend([self.convert_tokens_to_ids(\"[gMASK]\")])  # 使用特殊标记代替空字符串\n",
    "        # else:\n",
    "        #     input_message += \"[gMASK]\"\n",
    "\n",
    "    return input_ids if tokenize else input_message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "a8400577-cfae-44ab-a1a9-831177ec24c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_chat_template(\n",
    "        self,\n",
    "        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], \"Conversation\"],\n",
    "        add_generation_prompt: bool = False,\n",
    "        tokenize: bool = True,\n",
    "        padding: bool = False,\n",
    "        truncation: bool = False,\n",
    "        max_length: Optional[int] = None,\n",
    "        return_tensors: Optional[Union[str, TensorType]] = None,\n",
    "        return_dict: bool = False,\n",
    "        tokenizer_kwargs: Optional[Dict[str, Any]] = None,\n",
    "        add_special_tokens: bool = True,\n",
    "        **kwargs,\n",
    ") -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:\n",
    "\n",
    "    if return_dict and not tokenize:\n",
    "        raise ValueError(\n",
    "            \"`return_dict=True` is incompatible with `tokenize=False`, because there is no dict \"\n",
    "            \"of tokenizer outputs to return.\"\n",
    "        )\n",
    "\n",
    "    def handle_single_conversation(conversation):\n",
    "        input_ids = self.get_prefix_tokens() if add_special_tokens else []\n",
    "        input_message = \"[gMASK]<sop>\" if add_special_tokens else \"\"\n",
    "        for item in conversation:\n",
    "            if item.get(\"tools\"):\n",
    "                tools = item[\"tools\"]\n",
    "                content = \"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。\"\n",
    "                for tool in tools:\n",
    "                    if tool[\"type\"] == \"function\":\n",
    "                        function = tool[\"function\"]\n",
    "                        content += f\"\\n\\n## {function['name']}\\n\\n{json.dumps(function, ensure_ascii=False, indent=4)}\"\n",
    "                        content += \"\\n在调用上述函数时，请使用 Json 格式表示调用的参数。\"\n",
    "                    elif tool[\"type\"] == \"python\":\n",
    "                        content += \"\\n\\n## python\\n\\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。\"\n",
    "                    elif tool[\"type\"] == \"simple_browser\":\n",
    "                        content += \"\\n\\n## simple_browser\\n\\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\\n`open_url(url: str)`：打开指定的 URL。\\n\\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\\n\\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。\"\n",
    "                    elif tool[\"type\"] == \"cogview\":\n",
    "                        content += \"\\n\\n## cogview\\n\\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。\"\n",
    "                    else:\n",
    "                        raise NotImplementedError(f\"Unknown tool type {tool['type']}\")\n",
    "                input = self.build_single_message(\"system\", \"\", content, tokenize=tokenize)\n",
    "                if tokenize:\n",
    "                    input_ids.extend(input)\n",
    "                else:\n",
    "                    input_message += input\n",
    "            if item[\"content\"]:\n",
    "                input = self.build_single_message(\n",
    "                    item[\"role\"],\n",
    "                    item.get(\"metadata\", \"\"),\n",
    "                    item[\"content\"],\n",
    "                    tokenize=tokenize\n",
    "                )\n",
    "                if tokenize:\n",
    "                    input_ids.extend(input)\n",
    "                else:\n",
    "                    input_message += input\n",
    "        if add_generation_prompt:\n",
    "            if tokenize:\n",
    "                input_ids.extend([self.convert_tokens_to_ids(\"<|assistant|>\")])\n",
    "            else:\n",
    "                input_message += \"<|assistant|>\"\n",
    "            # if tokenize:\n",
    "            #     input_ids.extend([self.convert_tokens_to_ids(\"[gMASK]\")])  # 使用特殊标记代替空字符串\n",
    "            # else:\n",
    "            #     input_message += \"[gMASK]\"\n",
    "\n",
    "        return input_ids if tokenize else input_message\n",
    "\n",
    "    # 处理不同会话格式的主逻辑\n",
    "    if isinstance(conversation, list) and all(isinstance(i, dict) for i in conversation):\n",
    "        result = handle_single_conversation(conversation)\n",
    "    elif isinstance(conversation, list) and all(isinstance(i, list) for i in conversation):\n",
    "        result = [handle_single_conversation(c) for c in conversation]\n",
    "    elif hasattr(conversation, \"messages\"):\n",
    "        result = handle_single_conversation(conversation.messages)\n",
    "    else:\n",
    "        raise ValueError(\"Invalid conversation format\")\n",
    "\n",
    "    if tokenize:\n",
    "        output = self.batch_encode_plus(\n",
    "            [result] if isinstance(result[0], int) else result,\n",
    "            padding=padding,\n",
    "            truncation=truncation,\n",
    "            max_length=max_length,\n",
    "            return_tensors=return_tensors,\n",
    "            is_split_into_words=True,\n",
    "            add_special_tokens=False\n",
    "        )\n",
    "        if return_dict:\n",
    "            return output\n",
    "        else:\n",
    "            return output[\"input_ids\"]\n",
    "    else:\n",
    "        return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "38eb42ba-0628-4f9d-bf29-410607552950",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: regex in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (2023.10.3)\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install regex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "80613542-7de6-458e-9b93-8dc022fc2801",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "regex.Regex(\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\\\r\\\\n\\\\p{L}\\\\p{N}]?\\\\p{L}+|\\\\p{N}{1,3}| ?[^\\\\s\\\\p{L}\\\\p{N}]+[\\\\r\\\\n]*|\\\\s*[\\\\r\\\\n]+|\\\\s+(?!\\\\S)|\\\\s+\", flags=regex.V0)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pat_str = \"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\\\r\\\\n\\\\p{L}\\\\p{N}]?\\\\p{L}+|\\\\p{N}{1,3}| ?[^\\\\s\\\\p{L}\\\\p{N}]+[\\\\r\\\\n]*|\\\\s*[\\\\r\\\\n]+|\\\\s+(?!\\\\S)|\\\\s+\"\n",
    "regex.compile(pat_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e08863f0-9acd-458b-9af8-4ba37f1d7c21",
   "metadata": {},
   "outputs": [],
   "source": [
    "import regex\n",
    "\n",
    "pat_str = r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\"\n",
    "pattern = regex.compile(pat_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kewei-ai",
   "language": "python",
   "name": "kewei-ai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
